# -*- coding: utf-8 -*-
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
import time
import torch
import logging as log
from torch import nn
from typing import Optional, Tuple, Union
from transformers.generation.stopping_criteria import (
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.streamers import BaseStreamer
from transformers.utils import ModelOutput
from transformers.generation.configuration_utils import GenerationConfig
import llm_bench_utils.hook_greedy_search as hook_greedy


logger = log.getLogger(__name__)


class GenerateDecoderOnlyOutput(ModelOutput):
    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


class GenerateEncoderDecoderOutput(ModelOutput):
    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]


# Transformers version: v4.43-release 868d36d29ec132deeaaf8571b25b6a1b911d0145
# Copied from https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/generation/utils.py#L2841
# Add the function of collecting latency
def new_sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    logits_warper: Optional[LogitsProcessorList],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        logits_warper (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
            to warp the prediction score distribution of the language modeling head applied before multinomial
            sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
            `generation_config`)
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample
    if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
        raise ValueError(
            "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
            f"{logits_warper})."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size = input_ids.shape[0]
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        tic = time.perf_counter()
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # forward pass to get next token
        tic_infer = time.perf_counter()
        outputs = self(**model_inputs, return_dict=True)
        hook_greedy.tm_infer_list.append(time.perf_counter() - tic_infer)

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
        # (the clone itself is always small)
        next_token_logits = outputs.logits[:, -1, :].clone()

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        if do_sample:
            next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        hook_greedy.tm_list.append(time.perf_counter() - tic)
        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
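

# Illustrative sketch (an assumption, not part of the upstream Transformers code): one way
# to install the latency-collecting sampler on a loaded model, so that per-step wall-clock
# times accumulate in `hook_greedy.tm_list` and forward-pass times in
# `hook_greedy.tm_infer_list`. The actual llm_bench hook class may wire this up differently.
def apply_sample_hook(model):
    """Bind `new_sample` as the model's `_sample` method (Transformers v4.43 naming)."""
    # Assumption: in v4.43 the multinomial-sampling loop lives in `GenerationMixin._sample`,
    # which is what the copied function above replaces.
    model._sample = new_sample.__get__(model, model.__class__)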