@@ -45,10 +45,10 @@ def setup_environment_variables():
         sys.exit(1)
 
     data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
-    data_dir = input(
-        "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
-    )
     if not os.path.isdir(data_dir):
+        data_dir = input(
+            "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
+        )
         create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
         if create.lower() == "y":
             os.makedirs(data_dir, exist_ok=True)
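
Note: after this reorder, the interactive prompt only fires when the resolved DATA_DIR is missing, so non-interactive runs with a valid DATA_DIR no longer block on input(). A minimal standalone sketch of the intended flow (the helper name resolve_data_dir is illustrative, not part of this script):

    import os

    def resolve_data_dir() -> str:
        # Prefer the DATA_DIR env var, falling back to the default cache path.
        data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
        if not os.path.isdir(data_dir):
            # Only prompt when the directory does not already exist.
            data_dir = input("Enter the directory for UCMStore to save kv cache: ")
            if input(f"Directory {data_dir} does not exist. Create it? (Y/n): ").lower() == "y":
                os.makedirs(data_dir, exist_ok=True)
        return data_dir
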
@@ -87,7 +87,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
         model=model,
         kv_transfer_config=ktc,
         max_model_len=32768,
-        gpu_memory_utilization=0.6,
+        gpu_memory_utilization=0.8,
         max_num_batched_tokens=30000,
         block_size=128,
         enforce_eager=True,
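
For context, gpu_memory_utilization is the fraction of GPU memory vLLM is allowed to reserve for model weights plus KV-cache blocks; raising it from 0.6 to 0.8 leaves more headroom for cached KV entries at long context lengths. A minimal standalone usage sketch (the model id is illustrative):

    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # illustrative; any local/HF model id
        max_model_len=32768,
        gpu_memory_utilization=0.8,  # fraction of each GPU vLLM may claim
        enforce_eager=True,          # skip CUDA graph capture, as in this script
    )
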
@@ -111,10 +111,14 @@ def print_output(
     start = time.time()
     outputs = llm.generate(prompt, sampling_params)
     print("-" * 50)
+    lines = []
     for output in outputs:
         generated_text = output.outputs[0].text
         print(f"Generated text: {generated_text!r}")
+        lines.append(generated_text + "\n")
     print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
+    with open("./newest_out.txt", "w") as f:
+        f.writelines(lines)
     print("-" * 50)
 
 
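
Note: because "./newest_out.txt" is opened with mode "w", each call to print_output overwrites the previous dump, so only the latest batch of generations survives. If both runs should be kept, a hypothetical variant could append instead:

    # Sketch only: append mode preserves earlier runs in the same file.
    with open("./newest_out.txt", "a") as f:
        f.writelines(lines)
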
@@ -140,24 +144,24 @@ def get_prompt(prompt):
 
     with build_llm_with_uc(module_path, name, model) as llm:
         prompts = []
-        batch_size = 5
+        batch_size = 20
         assert os.path.isfile(
             path_to_dataset
         ), f"Incorrect dataset path. Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`"
         with open(path_to_dataset, "r") as f:
-            for _ in range(batch_size):
-                line = f.readline()
-                if not line:
-                    break
-                data = json.loads(line)
-                context = data["context"]
-                question = data["input"]
-                prompts.append(get_prompt(f"{context}\n\n{question}"))
-
-        sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=100)
+            lines = f.readlines()
+            for i in range(batch_size):
+                line = lines[i]
+                data = json.loads(line)
+                context = data["context"]
+                question = data["input"]
+                prompts.append(get_prompt(f"{context}\n\n{question}"))
+
+        sampling_params = SamplingParams(
+            temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False
+        )
 
         print_output(llm, prompts, sampling_params, "first")
-        print_output(llm, prompts, sampling_params, "second")
 
 
 if __name__ == "__main__":
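
Two small caveats with the new batch loading: lines[i] raises IndexError if the dataset has fewer than batch_size (20) lines, whereas the old readline()/break loop degraded gracefully; and ignore_eos=False just spells out vLLM's default, so generation still stops at the EOS token, with max_tokens=256 as an upper bound. A bounds-safe sketch of the loop:

    # Sketch only: cap the batch at whatever the dataset actually contains.
    for line in lines[:batch_size]:
        data = json.loads(line)
        prompts.append(get_prompt(f"{data['context']}\n\n{data['input']}"))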