init commit
This commit is contained in:
248
test.py
Normal file
248
test.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import warnings
|
||||
from chat import QueryChatGPT, llm_configured, load_config
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
|
||||
# Directory containing this script; every data path below is anchored to it.
DIR = os.path.dirname(os.path.abspath(__file__))

# Prompt definitions, dangerous-function list, input folder, output folder.
PROMPT_PATH, DANGER_FUNC, IN_DIR, OUT_DIR = (
    os.path.join(DIR, leaf)
    for leaf in ('prompt.json', 'danger_func.json', 'input', 'output')
)
|
||||
|
||||
def get_prompt(name: str, _type: str, prompt_path: str = PROMPT_PATH) -> Optional[Dict[str, str]]:
    """
    Look up a prompt entry in the JSON prompt file.

    The file is iterated as a sequence of entries, each read through its
    'name', 'type' and 'prompt' keys; the first entry whose 'name' and
    'type' both match is used.

    Args:
        name: the name of the prompt
        _type: the type of the prompt
        prompt_path: the path of the prompt file

    Returns:
        the matching entry's 'prompt' value (documented by the original
        author as a dict containing two keys: 'role' and 'content'), or
        None when no entry matches

    Raises:
        ValueError: if the prompt file parses to an empty value
    """
    # JSON text is UTF-8 by specification, so do not depend on the locale.
    with open(prompt_path, 'r', encoding='utf-8') as f:
        prompts = json.load(f)

    # Explicit check rather than `assert`, which disappears under `python -O`.
    if not prompts:
        raise ValueError("no prompts found in {path}".format(path=prompt_path))

    return next(
        (
            entry['prompt']
            for entry in prompts
            if entry['name'] == name and entry['type'] == _type
        ),
        None,
    )
|
||||
|
||||
def read_decompile_code(file_path: str) -> Optional[str]:
    """
    Read the decompiled source file at `file_path` and return its text.

    Args:
        file_path: path of the file to read

    Returns:
        the full contents of the file as a string

    Note:
        When `file_path` does not exist, a warning is issued and the
        process is terminated through sys.exit(1); the function does not
        return in that case.
    """
    if not os.path.exists(file_path):
        # The placeholder is named, so the value must be passed by keyword;
        # passing it positionally raised KeyError before the warning was shown.
        warnings.warn("Fail to find {file_path}!".format(file_path=file_path))
        sys.exit(1)

    with open(file_path, 'r') as r:
        return r.read()
|
||||
|
||||
# Sample record. No function in this file reads it: each handler below
# assigns its own local variable named `data`.
data = dict(
    name="Henry",
    age=30,
    fuzzing=True,
    languages=["C", "Python", "Assembly"],
)
|
||||
|
||||
def handle_dprintf(file_name: str, code: str, patch_dprintf_file: str = 'patch_dprintf.json'):
    """
    Ask the LLM whether `code` has a dprintf format string vulnerability
    and log a positive answer.

    Nothing happens when the substring 'dprintf' is absent from `code`.
    Otherwise the 'dprintf'/'attack' prompt is fetched with get_prompt, its
    'content' template is filled with `code`, and the result is sent through
    QueryChatGPT().query. If the first four characters of the reply contain
    'yes' (case-insensitive), a record is appended to the JSON list stored
    at OUT_DIR/<file_name without extension>/<patch_dprintf_file>.

    Args:
        file_name: name of the analysed file; its stem names the output folder
        code: the source text handed to the LLM
        patch_dprintf_file: file name of the JSON log inside the output folder

    Raises:
        ValueError: if the prompt cannot be found, or an existing log file
            does not contain a JSON list
    """
    if 'dprintf' not in code:
        return

    print("begin to test dprintf in {file_name}.".format(file_name=file_name))

    # Output folder is named after the file without its extension. splitext
    # is used instead of slicing off two characters so names that do not end
    # in a one-letter suffix such as ".c" are handled as well.
    output_dir = os.path.join(OUT_DIR, os.path.splitext(file_name)[0])
    os.makedirs(output_dir, exist_ok=True)

    output_file = os.path.join(output_dir, patch_dprintf_file)
    print("The dprintf info store into the " + output_file)

    # get prompt from prompt.json
    prompt = get_prompt('dprintf', 'attack')
    if not prompt:
        # Explicit check rather than `assert`, which is stripped under -O.
        raise ValueError("prompt 'dprintf'/'attack' is missing or empty")

    q = QueryChatGPT()
    # assumes query() returns a str; the slicing below relies on it
    response = q.query(prompt['content'].format(code=code))

    # judge whether the program has a dprintf format string vulnerability:
    # only the first four characters of the reply are inspected
    verdict = response[:4].lower()
    print("response info : " + verdict)
    if 'yes' not in verdict:
        return

    # NOTE(review): handle_recv and handle_strcpy store the file under the
    # key "file_name"; the key here is left as "file_path" so the format of
    # existing logs does not change.
    data = {
        "file_path": file_name,
        "vul_info": response[4:],  # drop the 4-character verdict prefix, e.g. "yes,"
    }

    # Continue the existing log if there is one, otherwise start a new list.
    if os.path.exists(output_file):
        with open(output_file, 'r') as r:
            log = json.load(r)
        if not isinstance(log, list):
            raise ValueError(
                "{path} does not contain a JSON list".format(path=output_file))
    else:
        log = []

    # insert relevant vulnerability info into json file
    log.append(data)
    with open(output_file, 'w') as w:
        json.dump(log, w, indent=4)

    # TODO: invoke the dprintf patch API here once it exists.
|
||||
|
||||
def handle_recv(file_name: str, code: str, patch_recv_file: str = 'patch_recv.json'):
    """
    Ask the LLM whether `code` has a buffer overflow caused by recv, ask for
    a fix size, and log the outcome.

    Nothing happens when the substring 'recv' is absent from `code`.
    Otherwise the 'recv'/'attack' prompt is sent through QueryChatGPT().query;
    if the first four characters of the reply contain 'yes'
    (case-insensitive), the 'recv'/'patch' prompt is sent as well. From the
    second reply, the digits following 'size=' are stored as 'fix_size' and
    the rest of the line after the first comma as 'patch_info'. The record is
    appended to the JSON list stored at
    OUT_DIR/<file_name without extension>/<patch_recv_file>.

    Args:
        file_name: name of the analysed file; its stem names the output folder
        code: the source text handed to the LLM
        patch_recv_file: file name of the JSON log inside the output folder

    Raises:
        ValueError: if a prompt cannot be found, or an existing log file
            does not contain a JSON list
    """
    if 'recv' not in code:
        return

    print("begin to test recv in {file_name}.".format(file_name=file_name))

    # Output folder is named after the file without its extension. splitext
    # is used instead of slicing off two characters so names that do not end
    # in a one-letter suffix such as ".c" are handled as well.
    output_dir = os.path.join(OUT_DIR, os.path.splitext(file_name)[0])
    os.makedirs(output_dir, exist_ok=True)

    output_file = os.path.join(output_dir, patch_recv_file)
    print("The recv info store into the " + output_file)

    # get prompt from prompt.json
    prompt = get_prompt('recv', 'attack')
    if not prompt:
        # Explicit check rather than `assert`, which is stripped under -O.
        raise ValueError("prompt 'recv'/'attack' is missing or empty")

    q = QueryChatGPT()
    # assumes query() returns a str; the slicing below relies on it
    response = q.query(prompt['content'].format(code=code))
    print(response)

    # judge whether the program has a buffer overflow due to recv:
    # only the first four characters of the reply are inspected
    verdict = response[:4].lower()
    print("response info : " + verdict)
    if 'yes' not in verdict:
        return

    data = {
        "file_name": file_name,
        "vul_info": response[4:],  # drop the 4-character verdict prefix, e.g. "yes,"
    }

    # determine the specific size to fix the recv call:
    # a second query using the patch prompt, on a fresh client as before
    prompt = get_prompt('recv', 'patch')
    if not prompt:
        raise ValueError("prompt 'recv'/'patch' is missing or empty")

    q = QueryChatGPT()
    response = q.query(prompt['content'].format(code=code))
    print(response)

    # record modified size and the accompanying explanation
    size_match = re.search(r'size=(\d+)', response)
    info_match = re.search(r',\s*(.*)', response)
    if size_match is None or info_match is None:
        # A reply in an unexpected shape used to crash here with
        # AttributeError and lose the finding; store None instead.
        warnings.warn(
            "recv patch reply lacks 'size=<digits>' or a comma-separated "
            "description; storing None for the missing field(s)")
    data['fix_size'] = int(size_match.group(1)) if size_match else None
    data['patch_info'] = info_match.group(1) if info_match else None

    # Continue the existing log if there is one, otherwise start a new list.
    if os.path.exists(output_file):
        with open(output_file, 'r') as r:
            log = json.load(r)
        if not isinstance(log, list):
            raise ValueError(
                "{path} does not contain a JSON list".format(path=output_file))
    else:
        log = []

    # insert relevant vulnerability info into json file
    log.append(data)
    with open(output_file, 'w') as w:
        json.dump(log, w, indent=4)

    # TODO: invoke the recv patch API here once it exists.
|
||||
|
||||
def handle_strcpy(file_name: str, code: str, patch_strcpy_file: str = 'patch_strcpy.json'):
    """
    Ask the LLM whether `code` has a vulnerability related to strcpy and log
    a positive answer.

    Nothing happens when the substring 'strcpy' is absent from `code`.
    Otherwise the 'strcpy'/'attack' prompt is sent through
    QueryChatGPT().query; if the first four characters of the reply contain
    'yes' (case-insensitive), the 'strcpy'/'patch' prompt is sent as well and
    its reply is printed. Only the first reply is recorded: the record is
    appended to the JSON list stored at
    OUT_DIR/<file_name without extension>/<patch_strcpy_file>.

    Args:
        file_name: name of the analysed file; its stem names the output folder
        code: the source text handed to the LLM
        patch_strcpy_file: file name of the JSON log inside the output folder

    Raises:
        ValueError: if a prompt cannot be found, or an existing log file
            does not contain a JSON list
    """
    if 'strcpy' not in code:
        return

    print("begin to test strcpy in {file_name}.".format(file_name=file_name))

    # Output folder is named after the file without its extension. splitext
    # is used instead of slicing off two characters so names that do not end
    # in a one-letter suffix such as ".c" are handled as well.
    output_dir = os.path.join(OUT_DIR, os.path.splitext(file_name)[0])
    os.makedirs(output_dir, exist_ok=True)

    output_file = os.path.join(output_dir, patch_strcpy_file)
    # The message previously said "recv" (copied from handle_recv).
    print("The strcpy info store into the " + output_file)

    # get prompt from prompt.json
    prompt = get_prompt('strcpy', 'attack')
    if not prompt:
        # Explicit check rather than `assert`, which is stripped under -O.
        raise ValueError("prompt 'strcpy'/'attack' is missing or empty")

    q = QueryChatGPT()
    # assumes query() returns a str; the slicing below relies on it
    response = q.query(prompt['content'].format(code=code))
    print(response)

    # judge whether the program is vulnerable through strcpy:
    # only the first four characters of the reply are inspected
    verdict = response[:4].lower()
    print("response info : " + verdict)
    if 'yes' not in verdict:
        return

    data = {
        "file_name": file_name,
        "vul_info": response[4:],  # drop the 4-character verdict prefix, e.g. "yes,"
    }

    # Second query using the patch prompt, on a fresh client as before.
    prompt = get_prompt('strcpy', 'patch')
    if not prompt:
        raise ValueError("prompt 'strcpy'/'patch' is missing or empty")

    q = QueryChatGPT()
    response = q.query(prompt['content'].format(code=code))
    print(response)
    # TODO: the patch reply is only printed; nothing from it is stored in
    # `data` yet (handle_recv extracts 'fix_size' and 'patch_info' here).

    # Continue the existing log if there is one, otherwise start a new list.
    if os.path.exists(output_file):
        with open(output_file, 'r') as r:
            log = json.load(r)
        if not isinstance(log, list):
            raise ValueError(
                "{path} does not contain a JSON list".format(path=output_file))
    else:
        log = []

    # insert relevant vulnerability info into json file
    log.append(data)
    with open(output_file, 'w') as w:
        json.dump(log, w, indent=4)

    # TODO: invoke the strcpy patch API here once it exists.
|
||||
|
||||
def exp(test_file: str = "edit_extract.c"):
    """
    Run every configured dangerous-function check against one input file.

    The file is read from IN_DIR, the list of functions to check is loaded
    from DANGER_FUNC, and for each listed entry whose 'name' is 'dprintf',
    'recv' or 'strcpy' the matching handler is called with the file name and
    its contents. Entries with any other name are skipped.

    Args:
        test_file: name of the file inside IN_DIR to analyse. Defaults to
            "edit_extract.c"; the original author's comments also mention
            "recv_extract.c" and "dprintf_extract.c" as inputs.

    Raises:
        ValueError: if DANGER_FUNC parses to an empty value
    """
    print("Trying to test function normally")

    # read code which will be analyzed by the LLM
    code = read_decompile_code(os.path.join(IN_DIR, test_file))

    # read danger_func.json to determine the scope of the functions checked
    # (JSON text is UTF-8 by specification, so do not depend on the locale)
    with open(DANGER_FUNC, 'r', encoding='utf-8') as f:
        d_func = json.load(f)
    # Explicit check rather than `assert`, which is stripped under -O.
    if not d_func:
        raise ValueError(
            "no dangerous functions listed in {path}".format(path=DANGER_FUNC))

    # Dispatch table: function name -> handler.
    handlers = {
        'dprintf': handle_dprintf,
        'recv': handle_recv,
        'strcpy': handle_strcpy,
    }
    for _d in d_func:
        handler = handlers.get(_d['name'])
        if handler is not None:
            handler(test_file, code)
|
||||
|
||||
if __name__ == '__main__':
    # Refuse to run any analysis until the LLM backend reports it is set up.
    if not llm_configured():
        print('please complete llm access setup first...')
        # sys.exit() instead of the bare exit() helper, which is injected by
        # the site module for interactive use; the exit status is unchanged.
        sys.exit()
    exp()
|
||||
Reference in New Issue
Block a user