|
2 | 2 | The InferenceUtil is taken from https://github.com/FudanSELab/ClassEval/blob/master/generation/inference_util.py as we want to keep faith with the original implementation. |
3 | 3 | """ |
4 | 4 |
|
5 | | -import re |
6 | 5 | from enum import Enum |
7 | 6 |
|
8 | 7 | class ModelName(Enum): |
@@ -51,95 +50,7 @@ def generate_prompt(instruction, model_name): |
51 | 50 |
|
52 | 51 | ### Response: |
53 | 52 | """ |
54 | | - |
55 | | - # @staticmethod |
56 | | - # def get_leading_spaces(string): |
57 | | - # return len(string) - len(string.lstrip()) |
58 | | - |
59 | | - # @staticmethod |
60 | | - # def del_segment_notation(code): |
61 | | - # pattern = r'(""".*?""")' |
62 | | - # result = re.sub(pattern, '', code, flags = re.DOTALL) |
63 | | - # return result |
64 | | - |
65 | | - # @staticmethod |
66 | | - # def get_method_signature(code, method_name): |
67 | | - # method_def_prefix = "def " + method_name + '(' |
68 | | - # code_segment = code.split('):') |
69 | | - # for segment in code_segment: |
70 | | - # if method_def_prefix in segment: |
71 | | - # return " " + segment + "):" |
72 | | - # return "" |
73 | | - |
74 | | - # @staticmethod |
75 | | - # def add_desc_to_init(desc, class_init): |
76 | | - # class_init_list = class_init.split('\n') |
77 | | - # class_init_list[0] += " \n" + desc |
78 | | - # class_init = '\n'.join(class_init_list) |
79 | | - # return class_init |
80 | | - |
81 | | - # @staticmethod |
82 | | - # def extract_method_code(code, method_name): |
83 | | - # # extract code of method {method_name} from {code} |
84 | | - # output_split_identifier_list = ["### Response:", "@@ Response:", "[/INST]"] |
85 | | - # for identifier in output_split_identifier_list: |
86 | | - # if identifier in code: |
87 | | - # code = code.split(identifier)[1] |
88 | | - # break |
89 | | - |
90 | | - # pattern_list = [r"```python(.*?)```", r"\[PYTHON\](.*?)\[/PYTHON\]"] |
91 | | - # for pattern in pattern_list: |
92 | | - # code_part = re.findall(pattern, code, re.S) |
93 | | - # if code_part: |
94 | | - # code = code_part[0] |
95 | | - # break |
96 | | - |
97 | | - # code_list = code.split('\n') |
98 | | - |
99 | | - # method_code_list = [] |
100 | | - # method_def_prefix = "def " + method_name + '(' |
101 | | - # skip_line_list = ["```", '\r'] |
102 | | - # # extract generated method code corresponding method_name, the strategy is to find the line |
103 | | - # # has "def methodname(...)" and following lines have more leading spaces than the first "def" line |
104 | | - # for i, line in enumerate(code_list): |
105 | | - # if method_def_prefix in line: |
106 | | - # method_code_list = code_list[i:] |
107 | | - # break |
108 | | - |
109 | | - # if len(method_code_list) == 0: |
110 | | - # return "" |
111 | | - |
112 | | - # for i, line in enumerate(method_code_list): |
113 | | - # if line in skip_line_list: |
114 | | - # method_code_list[i] = "" |
115 | | - |
116 | | - # if InferenceUtil.get_leading_spaces(method_code_list[1]) - InferenceUtil.get_leading_spaces(method_code_list[0]) > 4: |
117 | | - # method_code_list[0] = " " * 4 + method_code_list[0] |
118 | | - |
119 | | - # first_line_leading_space = InferenceUtil.get_leading_spaces(method_code_list[0]) |
120 | | - # for i, line in enumerate(method_code_list[1:]): |
121 | | - # if InferenceUtil.get_leading_spaces(line) <= first_line_leading_space and len(line) > 0: |
122 | | - # method_code_list = method_code_list[:i + 1] |
123 | | - # break |
124 | | - |
125 | | - # for i, line in enumerate(method_code_list): |
126 | | - # method_code_list[i] = ' ' * (4 - first_line_leading_space) + line |
127 | 53 |
|
128 | | - # if 'self' not in method_code_list[0] and 'cls' not in method_code_list[0]: |
129 | | - # method_code_list.insert(0, ' ' * 4 + "@staticmethod") |
130 | | - |
131 | | - # line_notation_mark = 0 |
132 | | - # for line in method_code_list: |
133 | | - # if line == " " * 8 + "\"\"\"" or line == " " * 4 + "\"\"\"": |
134 | | - # line_notation_mark = line_notation_mark + 1 |
135 | | - # if line_notation_mark % 2 == 1: |
136 | | - # method_code_list.append(" " * 8 + "\"\"\"") |
137 | | - # method_code_list.append(" " * 8 + "pass") |
138 | | - |
139 | | - # method_code = '\n'.join(method_code_list) |
140 | | - # method_code = method_code.rstrip() + '\n' |
141 | | - # return method_code |
142 | | - |
143 | 54 |
|
144 | 55 | # FOR NOW, We default to using the prompts that work for GPT3.5 and the holistic strategy |
145 | 56 | def construct_prompt(info, model_name = ModelName.GPT_3_5, strategy = GenerationStrategy.Holistic): |
|
0 commit comments