init commit

This commit is contained in:
2024-10-15 10:49:25 +08:00
parent 96bd52bfcc
commit 8dd7505095
15 changed files with 1105 additions and 0 deletions

248
test.py Normal file
View File

@@ -0,0 +1,248 @@
import os
import re
import sys
import json
import warnings
from chat import QueryChatGPT, llm_configured, load_config
from typing import Optional, Dict, List, Tuple
DIR = os.path.dirname(os.path.abspath(__file__))
PROMPT_PATH = os.path.join(DIR, 'prompt.json')
DANGER_FUNC = os.path.join(DIR, 'danger_func.json')
IN_DIR = os.path.join(DIR, 'input')
OUT_DIR = os.path.join(DIR, 'output')
def get_prompt(name: str, _type: str, prompt_path: str = PROMPT_PATH) -> Optional[Dict[str, str]]:
    """
    Look up a prompt entry in a JSON prompt file.

    The file is loaded with ``json.load`` and iterated; each entry is
    indexed by the keys 'name', 'type' and 'prompt'.

    Args:
        name: the name of the prompt to find
        _type: the type of the prompt to find
        prompt_path: the path of the JSON prompt file
    Returns:
        the 'prompt' value of the first entry whose 'name' and 'type' both
        match (described by the original author as a dict with the keys
        'role' and 'content'), or None when no entry matches
    Raises:
        AssertionError: if the loaded JSON document is empty/falsy
    """
    with open(prompt_path, 'r') as handle:
        entries = json.load(handle)
    assert entries

    # First match wins; 'type' is only inspected once 'name' already matches.
    matching = (
        entry['prompt']
        for entry in entries
        if entry['name'] == name and entry['type'] == _type
    )
    return next(matching, None)
def read_decompile_code(file_path: str) -> Optional[str]:
    """
    Read the whole content of a source file as text.

    Args:
        file_path: path of the file to read
    Returns:
        the file content as a single string
    Raises:
        SystemExit: with exit code 1 (after emitting a warning) when
            ``file_path`` does not exist
    """
    if not os.path.exists(file_path):
        # The placeholder is a *named* field, so the value has to be passed
        # by keyword; positional arguments make str.format raise KeyError
        # before the warning is emitted and before sys.exit(1) is reached.
        warnings.warn("Fail to find {file_path}!".format(file_path=file_path))
        sys.exit(1)
    with open(file_path, 'r') as r:
        code_data = r.read()
    return code_data
# Module-level sample record. None of the functions visible in this file read
# it: each handler builds its own local ``data`` dict, which shadows this name.
data = {
    "name": "Henry",
    "age": 30,
    "fuzzing": True,
    "languages": ["C", "Python", "Assembly"],
}
def handle_dprintf(file_name: str, code: str, patch_dprintf_file: str = 'patch_dprintf.json'):
    """
    Ask the LLM whether ``code`` contains a dprintf format string
    vulnerability and append a positive finding to a JSON log.

    Nothing happens when ``code`` does not mention 'dprintf'. Otherwise the
    'dprintf'/'attack' prompt is filled with ``code`` and sent through
    ``QueryChatGPT``; the finding is recorded only when the first four
    characters of the reply contain 'yes' (case-insensitive).

    Args:
        file_name: name of the analysed file; its last two characters are
            dropped to name the output sub-directory (e.g. a '.c' suffix)
        code: the source text to analyse
        patch_dprintf_file: name of the JSON log file inside that directory
    Returns:
        None
    """
    if 'dprintf' not in code:
        return
    print(f"begin to test dprintf in {file_name}.")

    # The log lives in OUT_DIR/<file_name without its last two chars>/
    target_dir = '/'.join((OUT_DIR, file_name[:-2]))
    os.makedirs(target_dir, exist_ok=True)
    log_path = '/'.join((target_dir, patch_dprintf_file))
    print("The dprintf info store into the " + log_path)

    # Fetch the detection prompt from prompt.json and query the model.
    attack_prompt = get_prompt('dprintf', 'attack')
    assert attack_prompt
    client = QueryChatGPT()
    response = client.query(attack_prompt['content'].format(code=code))

    # Only the leading four characters decide the verdict.
    verdict = response[:4].lower()
    print("response info : " + verdict)
    if 'yes' not in verdict:
        return

    # response[4:] drops the leading "yes," marker of the reply.
    record = {"file_path": file_name, "vul_info": response[4:]}

    # Create an empty JSON list on first use, then read-append-rewrite.
    if not os.path.exists(log_path):
        with open(log_path, 'w') as fh:
            json.dump([], fh, indent=4)
    with open(log_path, 'r') as fh:
        history = json.load(fh)
    assert isinstance(history, list)
    history.append(record)
    with open(log_path, 'w') as fh:
        json.dump(history, fh, indent=4)
    # invoke patch api here
    # patch_dprintf(file_path)
def handle_recv(file_name: str, code: str, patch_recv_file: str = 'patch_recv.json'):
    """
    Ask the LLM whether ``code`` has a buffer overflow caused by recv and,
    if so, ask for a fix size and append the finding to a JSON log.

    Nothing happens when ``code`` does not mention 'recv'. The finding is
    recorded only when the first four characters of the reply to the
    'recv'/'attack' prompt contain 'yes' (case-insensitive). A second query
    with the 'recv'/'patch' prompt is then parsed for ``size=<digits>`` and
    for the text following the first comma on a line.

    Args:
        file_name: name of the analysed file; its last two characters are
            dropped to name the output sub-directory (e.g. a '.c' suffix)
        code: the source text to analyse
        patch_recv_file: name of the JSON log file inside that directory
    Returns:
        None. The appended record has the keys 'file_name', 'vul_info',
        'fix_size' (int, or None when the reply has no ``size=<digits>``)
        and 'patch_info' (text after the first comma, or the whole stripped
        reply when it contains no comma).
    """
    if 'recv' not in code:
        return
    print("begin to test recv in {file_name}.".format(file_name=file_name))

    # generate output file
    output_dir = OUT_DIR + '/' + file_name[:-2]
    os.makedirs(output_dir, exist_ok=True)
    output_file = output_dir + '/' + patch_recv_file
    print("The recv info store into the " + output_file)

    # get the detection prompt from prompt.json
    prompt = get_prompt('recv', 'attack')
    assert prompt
    q = QueryChatGPT()
    response = q.query(prompt['content'].format(code=code))
    print(response)

    # judge whether the program has a buffer overflow due to recv
    verdict = response[:4].lower()
    print("response info : " + verdict)
    if 'yes' not in verdict:
        return

    data = {
        "file_name": file_name,
        "vul_info": response[4:]  # skip the leading "yes," marker
    }

    # second query: ask the model which size fixes the recv call
    prompt = get_prompt('recv', 'patch')
    assert prompt
    q = QueryChatGPT()
    response = q.query(prompt['content'].format(code=code))
    print(response)

    # The reply is free-form model output, so either pattern may be absent.
    # re.search returns None in that case; calling .group() on it would
    # raise AttributeError and lose the already confirmed finding.
    size_match = re.search(r'size=(\d+)', response)
    if size_match is None:
        warnings.warn("no 'size=<n>' found in the recv patch response")
        data['fix_size'] = None
    else:
        data['fix_size'] = int(size_match.group(1))
    info_match = re.search(r',\s*(.*)', response)
    data['patch_info'] = info_match.group(1) if info_match else response.strip()

    # create an empty JSON list on first use, then read-append-rewrite
    if not os.path.exists(output_file):
        with open(output_file, 'w') as w:
            json.dump([], w, indent=4)
    with open(output_file, 'r') as r:
        log = json.load(r)
    assert isinstance(log, list)
    log.append(data)
    with open(output_file, 'w') as w:
        json.dump(log, w, indent=4)
    # invoke patch api here
    # patch_recv(file_path)
def handle_strcpy(file_name: str, code: str, patch_strcpy_file: str = 'patch_strcpy.json'):
    """
    Ask the LLM whether ``code`` has a vulnerability caused by strcpy and
    append a positive finding to a JSON log.

    Nothing happens when ``code`` does not mention 'strcpy'. The finding is
    recorded only when the first four characters of the reply to the
    'strcpy'/'attack' prompt contain 'yes' (case-insensitive). A second
    query with the 'strcpy'/'patch' prompt is issued and printed, but its
    reply is not stored in the log.

    Args:
        file_name: name of the analysed file; its last two characters are
            dropped to name the output sub-directory (e.g. a '.c' suffix)
        code: the source text to analyse
        patch_strcpy_file: name of the JSON log file inside that directory
    Returns:
        None. The appended record has the keys 'file_name' and 'vul_info'.
    """
    if 'strcpy' not in code:
        return
    print("begin to test strcpy in {file_name}.".format(file_name=file_name))

    # generate output file
    output_dir = OUT_DIR + '/' + file_name[:-2]
    os.makedirs(output_dir, exist_ok=True)
    output_file = output_dir + '/' + patch_strcpy_file
    # This message used to say "recv" (copy-paste from handle_recv).
    print("The strcpy info store into the " + output_file)

    # get the detection prompt from prompt.json
    prompt = get_prompt('strcpy', 'attack')
    assert prompt
    q = QueryChatGPT()
    response = q.query(prompt['content'].format(code=code))
    print(response)

    # judge whether the program is vulnerable due to strcpy
    verdict = response[:4].lower()
    print("response info : " + verdict)
    if 'yes' not in verdict:
        return

    data = {
        "file_name": file_name,
        "vul_info": response[4:]  # skip the leading "yes," marker
    }

    # second query: ask the model how to patch the strcpy call
    prompt = get_prompt('strcpy', 'patch')
    assert prompt
    q = QueryChatGPT()
    response = q.query(prompt['content'].format(code=code))
    print(response)
    # NOTE(review): the patch reply is only printed; nothing from it is
    # added to ``data``. Parsing it the way handle_recv does ('size=<n>'
    # plus the text after the first comma) was disabled by the original
    # author -- confirm the strcpy patch reply format before storing it.

    # create an empty JSON list on first use, then read-append-rewrite
    if not os.path.exists(output_file):
        with open(output_file, 'w') as w:
            json.dump([], w, indent=4)
    with open(output_file, 'r') as r:
        log = json.load(r)
    assert isinstance(log, list)
    log.append(data)
    with open(output_file, 'w') as w:
        json.dump(log, w, indent=4)
    # invoke patch api here
    # patch_strcpy(file_path)
def exp():
    """
    Run the vulnerability handlers against one hard-coded input file.

    Reads "edit_extract.c" from IN_DIR, loads the JSON list stored at
    DANGER_FUNC and, for every entry, calls the handler whose function name
    equals the entry's 'name' ('dprintf', 'recv' or 'strcpy'). Entries with
    any other name are ignored.

    Raises:
        AssertionError: if the loaded DANGER_FUNC document is empty/falsy
    """
    print("Trying to test function normally")
    test_file = "edit_extract.c"
    # other sample inputs: "recv_extract.c", "dprintf_extract.c"

    # the code that the LLM is asked to analyse
    code = read_decompile_code(IN_DIR + "/" + test_file)

    # danger_func.json decides which functions get checked
    with open(DANGER_FUNC, 'r') as fh:
        danger_funcs = json.load(fh)
    assert danger_funcs

    # Ordered (name, handler) pairs; compared with == like an if/elif chain.
    checkers = (
        ('dprintf', handle_dprintf),
        ('recv', handle_recv),
        ('strcpy', handle_strcpy),
    )
    for entry in danger_funcs:
        for func_name, checker in checkers:
            if entry['name'] == func_name:
                checker(test_file, code)
                break
if __name__ == '__main__':
    # llm_configured() comes from the project's chat module; a falsy result
    # is treated as "LLM access has not been set up yet".
    if not llm_configured():
        print('please complete llm access setup first...')
        # sys.exit() instead of the bare exit(): the latter is injected by
        # the site module for interactive use and is absent under
        # `python -S`. The exit status is unchanged.
        sys.exit()
    exp()