1+ import os
12from typing import List
23import random
34import sys
@@ -39,10 +40,8 @@ def shard_weights(env, weights, weight_shardings):
3940 sharded = {}
4041 for key , val in weights .items ():
4142 sharding = env .sharding_by_axis (weight_shardings .get (key , - 1 ))
42- print ("SHARDING" , key , sharding )
4343 with jax .default_device (jax .devices ("cpu" )[0 ]):
4444 arr = torch_xla2 .tensor .t2j (val )
45-
4645 arr = jax .device_put (arr , sharding )
4746 sharded [key ] = torchjax .to_torch (arr )
4847 return sharded
@@ -57,17 +56,16 @@ def create_engine(devices):
5756 FLAGS .override_batch_size ,
5857 FLAGS .max_input_length ,
5958 FLAGS .max_output_length ,
60- quant_config .enable_weight_quantization ,
6159 )
6260 tokenizer = AutoTokenizer .from_pretrained (FLAGS .model_id )
6361 env = environment .JetEngineEnvironment (env_data )
6462 env .hf_tokenizer = tokenizer
6563 model = fetch_models .instantiate_model_from_repo_id (FLAGS .model_id , env )
64+ # NOTE: quant_config is attached to env only after the model is instantiated,
65+ # because the model must be constructed as a float model first and then quantized.
66+ env .quant_config = quant_config
6667 if quant_config .enable_weight_quantization :
6768 quantize_model .quantize_model (model , quant_config )
68- print ("====== model =======" )
69- print (model )
70-
7169 weight_shardings = model .get_sharding_annotations ()
7270 sharded_weights = shard_weights (env , model .state_dict (), weight_shardings )
7371 env_data .quant_config = quant_config
@@ -202,7 +200,7 @@ def interactive():
202200 "<s>[INST] <<SYS>>\n You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n <</SYS>>\n \n Continue the following story.\n \n Kay didn't have shoes that fit her feet properly. She only wore sneakers, because the \n Choose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]" ,
203201 ]
204202 for prompt in prompts :
205- slot = random .randint (0 , FLAGS .batch_size - 1 )
203+ slot = random .randint (0 , FLAGS .override_batch_size - 1 )
206204 tokens , true_length = tokenizer .encode (prompt )
207205
208206 print (f"---- Input prompts are: { prompt } " )
@@ -330,10 +328,10 @@ def benchmark_offline():
330328 decode_time_ms = sum (dec_times [2 :]) * 1000 / 8
331329
332330 largest_prefill = max (prefill_times .items ())
333- print ("MAX tokens:" , FLAGS .batch_size / avg_decode_times )
331+ print ("MAX tokens:" , FLAGS .override_batch_size / avg_decode_times )
334332
335- time2 = (FLAGS .batch_size * FLAGS .max_decode_length ) / (
336- FLAGS .batch_size * largest_prefill [1 ]
333+ time2 = (FLAGS .override_batch_size * FLAGS .max_decode_length ) / (
334+ FLAGS .override_batch_size * largest_prefill [1 ]
337335 + FLAGS .max_decode_length * avg_decode_times
338336 )
339337 print ("MAX tokens 2:" , time2 )
@@ -351,6 +349,8 @@ def main():
351349
352350 def main_real (argv ):
353351 """Entry point"""
352+ jax .config .update ("jax_default_prng_impl" , "unsafe_rbg" )
353+ os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "0"
354354 if len (argv ) < 2 :
355355 print ("Invalid arguments. please specify 'list' or 'serve'" )
356356
0 commit comments