 package org.beehive.gpullama3.api.service;
 
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.inference.state.State;
+import jakarta.annotation.PostConstruct;
+import org.beehive.gpullama3.Options;
 import org.beehive.gpullama3.inference.sampler.Sampler;
-import org.springframework.beans.factory.annotation.Autowired;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.loader.ModelLoader;
+import org.springframework.boot.ApplicationArguments;
 import org.springframework.stereotype.Service;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
-import java.util.*;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
-import java.util.function.IntConsumer;
+
+import static org.beehive.gpullama3.inference.sampler.Sampler.selectSampler;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;
 
 @Service
 public class LLMService {
 
-    @Autowired
-    private ModelInitializationService initService;
+    private final ApplicationArguments args;
 
-    @Autowired
-    private TokenizerService tokenizerService;
+    private Options options;
+    private Model model;
 
-    public CompletableFuture<String> generateCompletion(
-            String prompt,
-            int maxTokens,
-            double temperature,
-            double topP,
-            List<String> stopSequences) {
+    public LLMService(ApplicationArguments args) {
+        this.args = args;
+    }
 
-        return CompletableFuture.supplyAsync(() -> {
-            try {
-                System.out.println("Starting completion generation...");
-                System.out.println("Prompt: " + prompt.substring(0, Math.min(50, prompt.length())) + "...");
-                System.out.println("Max tokens: " + maxTokens + ", Temperature: " + temperature);
-
-                // Get initialized components
-                Model model = initService.getModel();
-
-                // Convert prompt to tokens
-                List<Integer> promptTokens = tokenizerService.encode(prompt);
-                System.out.println("Prompt tokens: " + promptTokens.size());
-
-                // Convert stop sequences to token sets
-                Set<Integer> stopTokens = new HashSet<>();
-                if (stopSequences != null) {
-                    for (String stop : stopSequences) {
-                        stopTokens.addAll(tokenizerService.encode(stop));
-                    }
-                    System.out.println("Stop tokens: " + stopTokens.size());
-                }
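+    /**
+     * Runs once at startup, after dependency injection: parses the CLI options
+     * and loads the model into memory before the service accepts requests.
+     */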
+    @PostConstruct
+    public void init() {
+        try {
+            System.out.println("Initializing LLM service...");
+
+            // Step 1: Parse service options
+            System.out.println("Step 1: Parsing service options...");
+            options = Options.parseServiceOptions(args.getSourceArgs());
+            System.out.println("Model path: " + options.modelPath());
+            System.out.println("Context length: " + options.maxTokens());
+
+            // Step 2: Load model weights
+            System.out.println("\nStep 2: Loading model...");
+            System.out.println("Loading model from: " + options.modelPath());
+            model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
+            System.out.println("✓ Model loaded successfully");
+            System.out.println("  Model type: " + model.getClass().getSimpleName());
+            System.out.println("  Vocabulary size: " + model.configuration().vocabularySize());
+            System.out.println("  Context length: " + model.configuration().contextLength());
+
+            System.out.println("\n✓ Model service initialization completed successfully!");
+            System.out.println("=== Ready to serve requests ===\n");
 
-                // Create custom sampler with request-specific parameters
-                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                Sampler sampler = initService.getSampler();
+        } catch (Exception e) {
+            System.err.println("✗ Failed to initialize model service: " + e.getMessage());
+            e.printStackTrace();
+            throw new RuntimeException("Model initialization failed", e);
+        }
+    }
 
-                // Create state based on model type
-                State state = createStateForModel(model);
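+    // Convenience overload using the default generation settings
+    // (maxTokens = 150, temperature = 0.7, topP = 0.9).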
+    public String generateResponse(String message, String systemMessage) {
+        return generateResponse(message, systemMessage, 150, 0.7, 0.9);
+    }
 
-                // Generate tokens using your existing method
-                List<Integer> generatedTokens = model.generateTokens(
-                        state,
-                        0,
-                        promptTokens,
-                        stopTokens,
-                        maxTokens,
-                        sampler,
-                        false,
-                        token -> {} // No callback for non-streaming
-                );
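+    /**
+     * Builds a chat-formatted prompt from the system and user messages, runs
+     * blocking token generation, and returns the decoded response text.
+     */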
+    public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP) {
+        try {
+            // Create sampler and state like runInstructOnce
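+            // (Assumption, mirroring Llama3.java from which this project derives:
+            // selectSampler returns greedy argmax when temperature == 0 and a
+            // temperature-scaled top-p sampler otherwise.)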
+            Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, System.currentTimeMillis());
+            State state = model.createNewState();
 
-                // Decode tokens back to text
-                String result = tokenizerService.decode(generatedTokens);
-                System.out.println("Generated " + generatedTokens.size() + " tokens");
-                System.out.println("Completion finished successfully");
+            // Use model's ChatFormat
+            ChatFormat chatFormat = model.chatFormat();
+            List<Integer> promptTokens = new ArrayList<>();
 
-                return result;
+            // Add begin of text if needed
+            if (model.shouldAddBeginOfText()) {
+                promptTokens.add(chatFormat.getBeginOfText());
+            }
 
-            } catch (Exception e) {
-                System.err.println("Error generating completion: " + e.getMessage());
-                e.printStackTrace();
-                throw new RuntimeException("Error generating completion", e);
+            // Add system message properly formatted
+            if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+                promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
             }
-        });
-    }
 
-    public void generateStreamingCompletion(
-            String prompt,
-            int maxTokens,
-            double temperature,
-            double topP,
-            List<String> stopSequences,
-            SseEmitter emitter) {
+            // Add user message properly formatted
+            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+            promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+            // Handle reasoning tokens if needed (for Deepseek-R1-Distill-Qwen)
+            if (model.shouldIncludeReasoning()) {
+                List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+                promptTokens.addAll(thinkStartTokens);
+            }
+
+            // Use proper stop tokens from chat format
+            Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+            long startTime = System.currentTimeMillis();
+
+            // Use CPU path for now (GPU path disabled as noted)
+            List<Integer> generatedTokens = model.generateTokens(
+                    state, 0, promptTokens, stopTokens, maxTokens, sampler, false, token -> {}
+            );
 
+            // Remove stop tokens if present
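+            // (List.getLast()/removeLast() require the SequencedCollection API
+            // added in Java 21.)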
+            if (!generatedTokens.isEmpty() && stopTokens.contains(generatedTokens.getLast())) {
+                generatedTokens.removeLast();
+            }
+
+            long duration = System.currentTimeMillis() - startTime;
+            double tokensPerSecond = generatedTokens.size() * 1000.0 / duration;
+            System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+                    generatedTokens.size(), duration, tokensPerSecond);
+
+            String responseText = model.tokenizer().decode(generatedTokens);
+
+            // Add reasoning prefix for non-streaming if needed
+            if (model.shouldIncludeReasoning()) {
+                responseText = "<think>\n" + responseText;
+            }
+
+            return responseText;
+
+        } catch (Exception e) {
+            System.err.println("FAILED " + e.getMessage());
+            throw new RuntimeException("Failed to generate response", e);
+        }
+    }
+
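+    /**
+     * Streams a response over SSE: each displayable token is sent as its own
+     * event, followed by a "[DONE]" marker. Generation runs on a background
+     * thread so the servlet thread is released immediately.
+     */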
+    public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter) {
         CompletableFuture.runAsync(() -> {
             try {
-                System.out.println("Starting streaming completion generation...");
-
-                Model model = initService.getModel();
+                Sampler sampler = selectSampler(model.configuration().vocabularySize(), 0.7f, 0.9f, System.currentTimeMillis());
+                State state = model.createNewState();
 
-                List<Integer> promptTokens = tokenizerService.encode(prompt);
+                // Use proper chat format like in runInstructOnce
+                ChatFormat chatFormat = model.chatFormat();
+                List<Integer> promptTokens = new ArrayList<>();
 
-                Set<Integer> stopTokens = new HashSet<>();
-                if (stopSequences != null) {
-                    for (String stop : stopSequences) {
-                        stopTokens.addAll(tokenizerService.encode(stop));
-                    }
+                if (model.shouldAddBeginOfText()) {
+                    promptTokens.add(chatFormat.getBeginOfText());
                 }
 
-                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                Sampler sampler = initService.getSampler();
-                State state = createStateForModel(model);
-
-                final int[] tokenCount = {0};
+                if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+                    promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
+                }
 
-                // Streaming callback
-                IntConsumer tokenCallback = token -> {
-                    try {
-                        String tokenText = tokenizerService.decode(List.of(token));
-                        tokenCount[0]++;
+                promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+                promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
 
-                        String eventData = String.format(
-                                "data: {\"choices\":[{\"text\":\"%s\",\"index\":0,\"finish_reason\":null}]}\n\n",
-                                escapeJson(tokenText)
-                        );
+                // Handle reasoning tokens for streaming
+                if (model.shouldIncludeReasoning()) {
+                    List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+                    promptTokens.addAll(thinkStartTokens);
+                    emitter.send(SseEmitter.event().data("<think>\n")); // Output immediately
+                }
 
-                        emitter.send(SseEmitter.event().data(eventData));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
 
-                        if (tokenCount[0] % 10 == 0) {
-                            System.out.println("Streamed " + tokenCount[0] + " tokens");
+                final int[] tokenCount = {0};
+                long startTime = System.currentTimeMillis();
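+                // Note: streaming currently hard-codes the 150-token limit and
+                // the 0.7 / 0.9 sampling defaults of the non-streaming overload.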
+                List<Integer> generatedTokens = model.generateTokens(
+                        state, 0, promptTokens, stopTokens, 150, sampler, false,
+                        token -> {
+                            try {
+                                // Only display tokens that should be displayed (like in your original)
+                                if (model.tokenizer().shouldDisplayToken(token)) {
+                                    String tokenText = model.tokenizer().decode(List.of(token));
+                                    emitter.send(SseEmitter.event().data(tokenText));
+                                    tokenCount[0]++;
+                                }
+                            } catch (Exception e) {
+                                emitter.completeWithError(e);
+                            }
                         }
+                );
 
-                    } catch (Exception e) {
-                        System.err.println("Error in streaming callback: " + e.getMessage());
-                        emitter.completeWithError(e);
-                    }
-                };
-
-                model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenCallback);
+                long duration = System.currentTimeMillis() - startTime;
+                double tokensPerSecond = tokenCount[0] * 1000.0 / duration;
+                System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+                        tokenCount[0], duration, tokensPerSecond);
 
-                // Send completion event
-                emitter.send(SseEmitter.event().data("data: [DONE]\n\n"));
+                emitter.send(SseEmitter.event().data("[DONE]"));
                 emitter.complete();
 
-                System.out.println("Streaming completion finished. Total tokens: " + tokenCount[0]);
-
             } catch (Exception e) {
-                System.err.println("Error in streaming generation: " + e.getMessage());
-                e.printStackTrace();
+                System.err.println("FAILED " + e.getMessage());
                 emitter.completeWithError(e);
             }
         });
     }
 
-    /**
-     * Create appropriate State subclass based on the model type
-     */
-    private State createStateForModel(Model model) {
-        try {
-            return model.createNewState();
-        } catch (Exception e) {
-            throw new RuntimeException("Failed to create state for model", e);
+    // Getters for other services to access the initialized components
+    public Options getOptions() {
+        if (options == null) {
+            throw new IllegalStateException("Model service not initialized yet");
         }
+        return options;
     }
 
-    private String escapeJson(String str) {
-        if (str == null) return "";
-        return str.replace("\"", "\\\"")
-                .replace("\n", "\\n")
-                .replace("\r", "\\r")
-                .replace("\t", "\\t")
-                .replace("\\", "\\\\");
+    public Model getModel() {
+        if (model == null) {
+            throw new IllegalStateException("Model service not initialized yet");
+        }
+        return model;
     }
 }