"""Ask an LLM whether decompiled code misuses dangerous C functions.

For each function name listed in ``danger_func.json`` that has a handler
here (``dprintf``, ``recv``, ``strcpy``), the matching prompt is loaded from
``prompt.json``, sent to the LLM together with the code, and positive
answers are appended to a JSON log under ``output/<input name>/``.
"""
import os
import re
import sys
import json
import warnings
from typing import Optional, Dict, List, Tuple

from chat import QueryChatGPT, llm_configured, load_config

DIR = os.path.dirname(os.path.abspath(__file__))
PROMPT_PATH = os.path.join(DIR, 'prompt.json')
DANGER_FUNC = os.path.join(DIR, 'danger_func.json')
IN_DIR = os.path.join(DIR, 'input')
OUT_DIR = os.path.join(DIR, 'output')


def get_prompt(name: str, _type: str, prompt_path: str = PROMPT_PATH) -> Optional[Dict[str, str]]:
    """
    Access the prompt

    Args:
        name: the name of the prompt
        _type: the type of the prompt
        prompt_path: the path of the prompt file

    Returns:
        the 'prompt' entry of the first record whose 'name' and 'type' match
        (callers read its 'content' key), or None when no record matches

    Raises:
        AssertionError: if the prompt file parses to an empty/falsy value
    """
    with open(prompt_path, 'r') as f:
        prompts = json.load(f)
    assert prompts, "no prompts found in {path}".format(path=prompt_path)

    for _p in prompts:
        if _p['name'] == name and _p['type'] == _type:
            return _p['prompt']
    return None


def read_decompile_code(file_path: str) -> Optional[str]:
    """
    Read the whole file at ``file_path`` and return it as text.

    If the file does not exist, a warning is emitted and the process exits
    with status 1.
    """
    try:
        with open(file_path, 'r') as r:
            return r.read()
    except FileNotFoundError:
        # The original passed positional arguments to a *named* field, which
        # raised KeyError instead of producing this warning.
        warnings.warn("Fail to find {file_path}!".format(file_path=file_path))
        sys.exit(1)


# NOTE(review): nothing in this file reads this module-level value (every
# handler builds its own local ``data``); kept so the module's public names
# are unchanged -- confirm whether it can be removed.
data = {
    "name": "Henry",
    "age": 30,
    "fuzzing": True,
    "languages": ["C", "Python", "Assembly"]
}


def _prepare_output_file(file_name: str, json_name: str) -> str:
    """Create ``OUT_DIR/<file_name without extension>`` and return the path of ``json_name`` in it."""
    # splitext instead of file_name[:-2]: same result for '.c' inputs, but
    # also correct for extensions that are not exactly two characters long.
    output_dir = os.path.join(OUT_DIR, os.path.splitext(file_name)[0])
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, json_name)


def _query_llm(name: str, _type: str, code: str) -> str:
    """Load prompt (``name``, ``_type``), fill in ``code`` and return the LLM reply."""
    prompt = get_prompt(name, _type)
    assert prompt, "prompt ({name}, {_type}) not found".format(name=name, _type=_type)
    # presumably QueryChatGPT.query returns the reply text; callers slice and
    # regex-search it as a string -- verify against the chat module
    return QueryChatGPT().query(prompt['content'].format(code=code))


def _answered_yes(response: str) -> bool:
    """Print the first four characters of ``response`` (lower-cased) and report whether they contain 'yes'."""
    verdict = response[:4].lower()
    print("response info : " + verdict)
    return 'yes' in verdict


def _append_record(output_file: str, record: dict) -> None:
    """Append ``record`` to the JSON list stored in ``output_file``, creating the file if needed."""
    if os.path.exists(output_file):
        with open(output_file, 'r') as r:
            log = json.load(r)
    else:
        log = []
    assert isinstance(log, list), "{path} does not contain a JSON list".format(path=output_file)

    log.append(record)
    with open(output_file, 'w') as w:
        json.dump(log, w, indent=4)


def handle_dprintf(file_name: str, code: str, patch_dprintf_file: str = 'patch_dprintf.json') -> None:
    """
    Ask the LLM whether ``code`` has a dprintf format string vulnerability.

    Does nothing when 'dprintf' does not occur in ``code``. When the reply
    starts with "yes", a record is appended to
    ``OUT_DIR/<file_name without extension>/<patch_dprintf_file>``.

    Args:
        file_name: name of the analysed input file
        code: the code text sent to the LLM
        patch_dprintf_file: name of the JSON log file
    """
    if 'dprintf' not in code:
        return
    print("begin to test dprintf in {file_name}.".format(file_name=file_name))

    # generate output file
    output_file = _prepare_output_file(file_name, patch_dprintf_file)
    print("The dprintf info store into the " + output_file)

    # judge whether the program exists dprintf format string vulnerability
    response = _query_llm('dprintf', 'attack', code)
    if not _answered_yes(response):
        return

    # NOTE: the key is 'file_path' here while the recv/strcpy records use
    # 'file_name'; left as-is so existing logs and consumers keep working.
    record = {
        "file_path": file_name,
        "vul_info": response[4:]  # drop the first four characters (the verdict prefix)
    }
    _append_record(output_file, record)

    # invoke patch api here
    # patch_dprintf(file_path)


def handle_recv(file_name: str, code: str, patch_recv_file: str = 'patch_recv.json') -> None:
    """
    Ask the LLM whether ``code`` has a buffer overflow caused by recv.

    Does nothing when 'recv' does not occur in ``code``. When the reply
    starts with "yes", a second (patch) prompt is sent and a record with the
    vulnerability info, the suggested size and the patch text is appended to
    ``OUT_DIR/<file_name without extension>/<patch_recv_file>``.

    Args:
        file_name: name of the analysed input file
        code: the code text sent to the LLM
        patch_recv_file: name of the JSON log file
    """
    if 'recv' not in code:
        return
    print("begin to test recv in {file_name}.".format(file_name=file_name))

    # generate output file
    output_file = _prepare_output_file(file_name, patch_recv_file)
    print("The recv info store into the " + output_file)

    # judge whether the program exists buffer overflow vulnerability due to recv func
    response = _query_llm('recv', 'attack', code)
    print(response)
    if not _answered_yes(response):
        return

    record = {
        "file_name": file_name,
        "vul_info": response[4:]  # drop the first four characters (the verdict prefix)
    }

    # determine the specific size to fix recv func
    response = _query_llm('recv', 'patch', code)
    print(response)

    # record modified size; a reply without "size=<digits>" used to crash with
    # AttributeError and lose the finding, so record None and warn instead
    size_match = re.search(r'size=(\d+)', response)
    if size_match is None:
        warnings.warn("no 'size=<n>' found in the recv patch response for {file_name}".format(
            file_name=file_name))
    record['fix_size'] = int(size_match.group(1)) if size_match else None

    # patch text is whatever follows the first comma on that line; fall back to
    # the whole reply when there is no comma at all
    info_match = re.search(r',\s*(.*)', response)
    record['patch_info'] = info_match.group(1) if info_match else response

    _append_record(output_file, record)

    # invoke patch api here
    # patch_recv(file_path)


def handle_strcpy(file_name: str, code: str, patch_strcpy_file: str = 'patch_strcpy.json') -> None:
    """
    Ask the LLM whether ``code`` has a buffer overflow caused by strcpy.

    Does nothing when 'strcpy' does not occur in ``code``. When the reply
    starts with "yes", the patch prompt is also sent (its reply is only
    printed, not stored) and a record is appended to
    ``OUT_DIR/<file_name without extension>/<patch_strcpy_file>``.

    Args:
        file_name: name of the analysed input file
        code: the code text sent to the LLM
        patch_strcpy_file: name of the JSON log file
    """
    if 'strcpy' not in code:
        return
    print("begin to test strcpy in {file_name}.".format(file_name=file_name))

    # generate output file
    output_file = _prepare_output_file(file_name, patch_strcpy_file)
    # message previously said "recv" (copy-paste slip)
    print("The strcpy info store into the " + output_file)

    # judge whether the program exists buffer overflow vulnerability due to strcpy func
    response = _query_llm('strcpy', 'attack', code)
    print(response)
    if not _answered_yes(response):
        return

    record = {
        "file_name": file_name,
        "vul_info": response[4:]  # drop the first four characters (the verdict prefix)
    }

    # get the patch suggestion for strcpy
    # TODO: the reply is printed but not parsed into the record yet
    response = _query_llm('strcpy', 'patch', code)
    print(response)

    _append_record(output_file, record)

    # invoke patch api here
    # patch_strcpy(file_path)


def exp():
    """Run every handler named in ``danger_func.json`` against one hard-coded input file."""
    print("Trying to test function normally")
    # alternative inputs used previously: "recv_extract.c", "dprintf_extract.c"
    test_file = "edit_extract.c"

    # read code which will be used to analyze by LLM
    code = read_decompile_code(os.path.join(IN_DIR, test_file))

    # read danger_func.json to determine the scope of the func checked
    with open(DANGER_FUNC, 'r') as f:
        d_func = json.load(f)
    assert d_func, "no entries found in {path}".format(path=DANGER_FUNC)

    handlers = {
        'dprintf': handle_dprintf,
        'recv': handle_recv,
        'strcpy': handle_strcpy,
    }
    for _d in d_func:
        # entries without a handler are skipped, as before
        handler = handlers.get(_d['name'])
        if handler is not None:
            handler(test_file, code)


if __name__ == '__main__':
    if not llm_configured():
        print('please complete llm access setup first...')
        # sys.exit() instead of the site-provided exit(); same exit status
        sys.exit()
    exp()