11mod backend;
2+ mod llamacpp;
3+ mod quantize;
4+
5+ use quantize:: QuantizeType ;
26
37use backend:: {
48 BackendError , LlamacppBackend , LlamacppConfig , LlamacppGGMLType , LlamacppNuma ,
59 LlamacppSplitMode ,
610} ;
711use clap:: Parser ;
12+ use hf_hub:: api:: tokio:: ApiBuilder ;
13+ use hf_hub:: { Repo , RepoType } ;
14+ use std:: path:: Path ;
815use text_generation_router:: { logging, server, usage_stats} ;
916use thiserror:: Error ;
10- use tokenizers:: { FromPretrainedParameters , Tokenizer } ;
17+ use tokenizers:: Tokenizer ;
18+ use tokio:: process:: Command ;
1119use tokio:: sync:: oneshot:: error:: RecvError ;
1220use tracing:: { error, warn} ;
1321
@@ -25,7 +33,7 @@ struct Args {
2533
2634 /// Path to the GGUF model file for inference. If omitted, `models/<model_id>/model.gguf` is used and generated from the Hub model when missing.
2735 #[ clap( long, env) ]
28- model_gguf : String , // TODO Option() with hf->gguf & quantize
36+ model_gguf : Option < String > ,
2937
3038 /// Number of threads to use for generation.
3139 #[ clap( long, env) ]
@@ -53,19 +61,19 @@ struct Args {
5361
5462 /// Disable memory mapping for the model.
5563 #[ clap( long, env) ]
56- use_mmap : bool ,
64+ disable_mmap : bool ,
5765
5866 /// Use memory locking to prevent swapping.
5967 #[ clap( long, env) ]
6068 use_mlock : bool ,
6169
6270 /// Disable offloading of KQV operations to the GPU.
6371 #[ clap( long, env) ]
64- offload_kqv : bool ,
72+ disable_offload_kqv : bool ,
6573
6674 /// Disable flash attention, which is otherwise enabled for faster inference. (EXPERIMENTAL)
6775 #[ clap( long, env) ]
68- flash_attention : bool ,
76+ disable_flash_attention : bool ,
6977
7078 /// Data type used for K cache.
7179 #[ clap( default_value = "f16" , value_enum, long, env) ]
@@ -194,35 +202,80 @@ async fn main() -> Result<(), RouterError> {
194202 ) ) ;
195203 }
196204
197- // TODO: check if we use the same cache of Server
198- // check if llamacpp is faster
199- let tokenizer = {
200- let token = std:: env:: var ( "HF_TOKEN" )
201- . or_else ( |_| std:: env:: var ( "HUGGING_FACE_HUB_TOKEN" ) )
202- . ok ( ) ;
203- let params = FromPretrainedParameters {
204- revision : args. revision . clone ( ) ,
205- token,
206- ..Default :: default ( )
207- } ;
208- Tokenizer :: from_pretrained ( args. model_id . clone ( ) , Some ( params) ) ?
205+ let api_builder = || {
206+ let mut builder = ApiBuilder :: new ( ) . with_progress ( true ) ;
207+
208+ if let Ok ( cache_dir) = std:: env:: var ( "HUGGINGFACE_HUB_CACHE" ) {
209+ builder = builder. with_cache_dir ( cache_dir. into ( ) ) ;
210+ }
211+ if let Ok ( token) = std:: env:: var ( "HF_TOKEN" ) {
212+ builder = builder. with_token ( token. into ( ) ) ;
213+ }
214+ if let Ok ( origin) = std:: env:: var ( "HF_HUB_USER_AGENT_ORIGIN" ) {
215+ builder = builder. with_user_agent ( "origin" , origin. as_str ( ) ) ;
216+ }
217+ builder
218+ } ;
219+ let api_repo = api_builder ( ) . build ( ) ?. repo ( Repo :: with_revision (
220+ args. model_id . clone ( ) ,
221+ RepoType :: Model ,
222+ args. revision . clone ( ) ,
223+ ) ) ;
224+
225+ let tokenizer_path = api_repo. get ( "tokenizer.json" ) . await ?;
226+ let tokenizer = Tokenizer :: from_file ( & tokenizer_path) ?;
227+
228+ let model_gguf = if let Some ( model_gguf) = args. model_gguf {
229+ model_gguf
230+ } else {
231+ let model_gguf = format ! ( "models/{}/model.gguf" , args. model_id) ;
232+ let model_gguf_path = Path :: new ( & model_gguf) ;
233+
234+ if !model_gguf_path. exists ( ) {
235+ let tmp_gguf = "models/tmp.gguf" ;
236+
237+ if let Some ( parent) = Path :: new ( model_gguf_path) . parent ( ) {
238+ std:: fs:: create_dir_all ( parent) ?;
239+ }
240+ let cache_path = tokenizer_path. parent ( ) . unwrap ( ) ;
241+
242+ for sibling in api_repo. info ( ) . await ?. siblings {
243+ let _ = api_repo. get ( & sibling. rfilename ) . await ?;
244+ }
245+ let status = Command :: new ( "convert_hf_to_gguf.py" )
246+ . arg ( "--outfile" )
247+ . arg ( tmp_gguf)
248+ . arg ( cache_path)
249+ . spawn ( ) ?
250+ . wait ( )
251+ . await ?;
252+
253+ if !status. success ( ) {
254+ let exit_code = status. code ( ) . unwrap_or ( -1 ) ;
255+ error ! ( "Failed to generate GGUF, exit code: {}" , exit_code) ;
256+ return Err ( RouterError :: CommandError ( exit_code) ) ;
257+ }
258+ quantize:: model ( tmp_gguf, & model_gguf, QuantizeType :: MostlyQ4_0 , n_threads)
259+ . map_err ( RouterError :: QuantizeError ) ?;
260+ }
261+ model_gguf
209262 } ;
210263
211264 let ( backend, ok, shutdown) = LlamacppBackend :: new (
212265 LlamacppConfig {
213- model_gguf : args . model_gguf ,
266+ model_gguf,
214267 n_threads,
215268 n_threads_batch,
216269 n_gpu_layers : args. n_gpu_layers ,
217270 split_mode : args. split_mode ,
218271 defrag_threshold : args. defrag_threshold ,
219272 numa : args. numa ,
220- use_mmap : args. use_mmap ,
273+ use_mmap : ! args. disable_mmap ,
221274 use_mlock : args. use_mlock ,
222- flash_attention : args. flash_attention ,
275+ flash_attention : ! args. disable_flash_attention ,
223276 type_k : args. type_k ,
224277 type_v : args. type_v ,
225- offload_kqv : args. offload_kqv ,
278+ offload_kqv : ! args. disable_offload_kqv ,
226279 max_batch_total_tokens,
227280 max_physical_batch_total_tokens,
228281 max_batch_size,
@@ -281,4 +334,14 @@ enum RouterError {
281334 WebServer ( #[ from] server:: WebServerError ) ,
282335 #[ error( "Recv error: {0}" ) ]
283336 RecvError ( #[ from] RecvError ) ,
337+ #[ error( "Io error: {0}" ) ]
338+ IoError ( #[ from] std:: io:: Error ) ,
339+ #[ error( "Var error: {0}" ) ]
340+ VarError ( #[ from] std:: env:: VarError ) ,
341+ #[ error( "Quantize error: {0}" ) ]
342+ QuantizeError ( String ) ,
343+ #[ error( "Command error: {0}" ) ]
344+ CommandError ( i32 ) ,
345+ #[ error( "HF hub error: {0}" ) ]
346+ HubError ( #[ from] hf_hub:: api:: tokio:: ApiError ) ,
284347}
0 commit comments