@@ -710,22 +710,56 @@ def _create_completion(
710710 # We want to avoid yielding any characters from
711711 # the generated text if they are part of a stop
712712 # sequence.
713- longest = 0
713+ first_stop_position = 0
714714 for s in stop_sequences :
715715 for i in range (len (s ), 0 , - 1 ):
716716 if all_text .endswith (s [:i ]):
717- if i > longest :
718- longest = i
717+ if i > first_stop_position :
718+ first_stop_position = i
719719 break
720720
721- offset = 0
721+ token_end_position = 0
722722 remaining_tokens = completion_tokens [returned_tokens :]
723723 remaining_length = len (self .detokenize (remaining_tokens ))
724724 for token in remaining_tokens :
725- offset += len (self .detokenize ([token ]))
726- # Check if stop sequence is not in the token
727- if offset >= (remaining_length - longest - 1 ):
725+ token_end_position += len (self .detokenize ([token ]))
726+ # Check if stop sequence is in the token
727+ if token_end_position >= (remaining_length - first_stop_position - 1 ):
728728 break
729+ logprobs_or_none : Optional [CompletionLogprobs ] = None
730+ if logprobs is not None :
731+ token_str = self .detokenize ([token ]).decode (
732+ "utf-8" , errors = "ignore"
733+ )
734+ text_offset = len (prompt ) + len (
735+ self .detokenize (completion_tokens [:returned_tokens ])
736+ )
737+ token_offset = len (prompt_tokens ) + returned_tokens
738+ logits = self .eval_logits [token_offset - 1 ]
739+ current_logprobs = Llama .logits_to_logprobs (logits )
740+ sorted_logprobs = list (
741+ sorted (
742+ zip (current_logprobs , range (len (current_logprobs ))),
743+ reverse = True ,
744+ )
745+ )
746+ top_logprob = {
747+ self .detokenize ([llama_cpp .llama_token (i )]).decode (
748+ "utf-8" , errors = "ignore"
749+ ): logprob
750+ for logprob , i in sorted_logprobs [:logprobs ]
751+ }
752+ top_logprob .update ({token_str : current_logprobs [int (token )]})
753+ logprobs_or_none = {
754+ "tokens" : [
755+ self .detokenize ([token ]).decode (
756+ "utf-8" , errors = "ignore"
757+ )
758+ ],
759+ "text_offset" : [text_offset ],
760+ "token_logprobs" : [sorted_logprobs [int (token )][0 ]],
761+ "top_logprobs" : [top_logprob ],
762+ }
729763 returned_tokens += 1
730764 yield {
731765 "id" : completion_id ,
@@ -738,7 +772,7 @@ def _create_completion(
738772 "utf-8" , errors = "ignore"
739773 ),
740774 "index" : 0 ,
741- "logprobs" : None ,
775+ "logprobs" : logprobs_or_none ,
742776 "finish_reason" : None ,
743777 }
744778 ],
@@ -766,13 +800,48 @@ def _create_completion(
766800 else :
767801 end = len (all_text )
768802
769- offset = 0
803+ token_end_position = 0
770804 for token in remaining_tokens :
771- offset += len (self .detokenize ([token ]))
772- if offset >= end :
805+ token_end_position += len (self .detokenize ([token ]))
806+
807+ logprobs_or_none : Optional [CompletionLogprobs ] = None
808+ if logprobs is not None :
809+ token_str = self .detokenize ([token ]).decode (
810+ "utf-8" , errors = "ignore"
811+ )
812+ text_offset = len (prompt ) + len (
813+ self .detokenize (completion_tokens [:returned_tokens ])
814+ )
815+ token_offset = len (prompt_tokens ) + returned_tokens - 1
816+ logits = self .eval_logits [token_offset ]
817+ current_logprobs = Llama .logits_to_logprobs (logits )
818+ sorted_logprobs = list (
819+ sorted (
820+ zip (current_logprobs , range (len (current_logprobs ))),
821+ reverse = True ,
822+ )
823+ )
824+ top_logprob = {
825+ self .detokenize ([llama_cpp .llama_token (i )]).decode (
826+ "utf-8" , errors = "ignore"
827+ ): logprob
828+ for logprob , i in sorted_logprobs [:logprobs ]
829+ }
830+ top_logprob .update ({token_str : current_logprobs [int (token )]})
831+ logprobs_or_none = {
832+ "tokens" : [
833+ self .detokenize ([token ]).decode ("utf-8" , errors = "ignore" )
834+ ],
835+ "text_offset" : [text_offset ],
836+ "token_logprobs" : [sorted_logprobs [int (token )][0 ]],
837+ "top_logprobs" : [top_logprob ],
838+ }
839+
840+ if token_end_position >= end :
773841 last_text = self .detokenize ([token ])
774- if offset == end - 1 :
842+ if token_end_position == end - 1 :
775843 break
844+ returned_tokens += 1
776845 yield {
777846 "id" : completion_id ,
778847 "object" : "text_completion" ,
@@ -781,10 +850,10 @@ def _create_completion(
781850 "choices" : [
782851 {
783852 "text" : last_text [
784- : len (last_text ) - (offset - end )
853+ : len (last_text ) - (token_end_position - end )
785854 ].decode ("utf-8" , errors = "ignore" ),
786855 "index" : 0 ,
787- "logprobs" : None ,
856+ "logprobs" : logprobs_or_none ,
788857 "finish_reason" : finish_reason ,
789858 }
790859 ],
@@ -802,7 +871,7 @@ def _create_completion(
802871 "utf-8" , errors = "ignore"
803872 ),
804873 "index" : 0 ,
805- "logprobs" : None ,
874+ "logprobs" : logprobs_or_none ,
806875 "finish_reason" : finish_reason
807876 if returned_tokens == len (completion_tokens )
808877 else None ,
@@ -821,21 +890,27 @@ def _create_completion(
821890
822891 logprobs_or_none : Optional [CompletionLogprobs ] = None
823892 if logprobs is not None :
824- text_offset = 0
893+ text_offset = 0 if echo else len (prompt )
894+ token_offset = 0 if echo else len (prompt_tokens [1 :])
825895 text_offsets : List [int ] = []
826- token_logprobs : List [float ] = []
896+ token_logprobs : List [Optional [ float ] ] = []
827897 tokens : List [str ] = []
828- top_logprobs : List [Dict [str , float ]] = []
898+ top_logprobs : List [Optional [Dict [str , float ]]] = []
899+
900+ if echo :
901+ # Remove leading BOS token
902+ all_tokens = prompt_tokens [1 :] + completion_tokens
903+ else :
904+ all_tokens = completion_tokens
829905
830- all_tokens = prompt_tokens + completion_tokens
831906 all_token_strs = [
832907 self .detokenize ([token ]).decode ("utf-8" , errors = "ignore" )
833908 for token in all_tokens
834909 ]
835910 all_logprobs = [
836911 Llama .logits_to_logprobs (list (map (float , row )))
837912 for row in self .eval_logits
838- ]
913+ ][ token_offset :]
839914 for token , token_str , logprobs_token in zip (
840915 all_tokens , all_token_strs , all_logprobs
841916 ):
@@ -848,14 +923,20 @@ def _create_completion(
848923 )
849924 )
850925 token_logprobs .append (sorted_logprobs [int (token )][0 ])
851- top_logprob = {
926+ top_logprob : Optional [ Dict [ str , float ]] = {
852927 self .detokenize ([llama_cpp .llama_token (i )]).decode (
853928 "utf-8" , errors = "ignore"
854929 ): logprob
855930 for logprob , i in sorted_logprobs [:logprobs ]
856931 }
857- top_logprob .update ({token_str : sorted_logprobs [int (token )][ 0 ]})
932+ top_logprob .update ({token_str : logprobs_token [int (token )]})
858933 top_logprobs .append (top_logprob )
934+ # Weird idiosyncrasy of the OpenAI API where
935+ # token_logprobs and top_logprobs are null for
936+ # the first token.
937+ if echo and len (all_tokens ) > 0 :
938+ token_logprobs [0 ] = None
939+ top_logprobs [0 ] = None
859940 logprobs_or_none = {
860941 "tokens" : tokens ,
861942 "text_offset" : text_offsets ,
0 commit comments