4 files changed (+107 −61 lines), all under server/text_generation_server/models.

@@ -710,34 +710,41 @@ def forward(
 # )
 if SYSTEM == "ipex":
     attn_output = torch.empty_like(query_states)
-    ipex.llm.functional.varlen_attention(
-        (
-            query_states.contiguous()
-            if query_states.device.type == "xpu"
-            else query_states
-        ),
-        (
-            key_states.contiguous()
-            if key_states.device.type == "xpu"
-            else key_states
-        ),
-        (
-            value_states.contiguous()
-            if value_states.device.type == "xpu"
-            else value_states
-        ),
-        attn_output,
-        cu_seqlen_q,
-        cu_seqlen_k,
-        max_q,
-        max_k,
-        0.0,
-        self.softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if query_states.device.type == "xpu":
+        ipex.llm.functional.varlen_attention(
+            query_states.contiguous(),
+            key_states.contiguous(),
+            value_states.contiguous(),
+            attn_output,
+            cu_seqlen_q,
+            cu_seqlen_k,
+            None,
+            max_q,
+            max_k,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_output,
+            cu_seqlen_q,
+            cu_seqlen_k,
+            max_q,
+            max_k,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 else:
     attn_output = flash_attn_2_cuda.varlen_fwd(
         query_states,
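The same XPU/CPU split recurs in the two hunks below. A minimal sketch of a helper that would centralize the dispatch; the name `varlen_attention_compat` is hypothetical, and the meaning of the extra `None` positional argument on the XPU path (inserted between `cu_seqlen_k` and `max_q`) is inferred from this diff, not documented here:

```python
import torch


def varlen_attention_compat(
    ipex_module,            # the imported intel_extension_for_pytorch module
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    cu_seqlen_q: torch.Tensor,
    cu_seqlen_k: torch.Tensor,
    max_q: int,
    max_k: int,
    softmax_scale: float,
    causal: bool,
) -> torch.Tensor:
    """Hypothetical wrapper mirroring the dispatch introduced in this PR."""
    if query.device.type == "xpu":
        # XPU build: inputs must be contiguous, and the call takes one extra
        # positional argument (passed as None in this diff) before max_q.
        ipex_module.llm.functional.varlen_attention(
            query.contiguous(), key.contiguous(), value.contiguous(),
            out, cu_seqlen_q, cu_seqlen_k, None,
            max_q, max_k, 0.0, softmax_scale,
            False, causal, False, None,
        )
    else:
        # Non-XPU build: the older signature, no extra argument, no copies.
        ipex_module.llm.functional.varlen_attention(
            query, key, value,
            out, cu_seqlen_q, cu_seqlen_k,
            max_q, max_k, 0.0, softmax_scale,
            False, causal, False, None,
        )
    return out
```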
@@ -460,22 +460,41 @@ def forward(
 # execute flash attention
 if SYSTEM == "ipex":
     attn_output = torch.empty_like(query)
-    ipex.llm.functional.varlen_attention(
-        (query.contiguous() if query.device.type == "xpu" else query),
-        (key.contiguous() if key.device.type == "xpu" else key),
-        (value.contiguous() if value.device.type == "xpu" else value),
-        attn_output,
-        cu_seqlens,
-        cu_seqlens,
-        max_seqlen,
-        max_seqlen,
-        0.0,
-        self.softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if query.device.type == "xpu":
+        ipex.llm.functional.varlen_attention(
+            query.contiguous(),
+            key.contiguous(),
+            value.contiguous(),
+            attn_output,
+            cu_seqlens,
+            cu_seqlens,
+            None,
+            max_seqlen,
+            max_seqlen,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query,
+            key,
+            value,
+            attn_output,
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 else:
     attn_output = flash_attn_2_cuda.varlen_fwd(
         query,
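With the hypothetical `varlen_attention_compat` helper sketched after the first hunk, the duplicated branches in this hunk (and the identical one below) would collapse to a single call. A fragment using the names from the diff, assumed to run in the same `forward` context where `ipex`, `causal`, and `self.softmax_scale` are in scope:

```python
attn_output = torch.empty_like(query)
varlen_attention_compat(
    ipex, query, key, value, attn_output,
    cu_seqlens, cu_seqlens,      # self-attention: same offsets for q and k
    max_seqlen, max_seqlen,
    self.softmax_scale, causal,
)
```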
@@ -130,22 +130,41 @@ def forward(
 # execute flash attention
 if SYSTEM == "ipex":
     attn_output = torch.empty_like(query)
-    ipex.llm.functional.varlen_attention(
-        (query.contiguous() if query.device.type == "xpu" else query),
-        (key.contiguous() if key.device.type == "xpu" else key),
-        (value.contiguous() if value.device.type == "xpu" else value),
-        attn_output,
-        cu_seqlens,
-        cu_seqlens,
-        max_seqlen,
-        max_seqlen,
-        0.0,
-        self.softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if query.device.type == "xpu":
+        ipex.llm.functional.varlen_attention(
+            query.contiguous(),
+            key.contiguous(),
+            value.contiguous(),
+            attn_output,
+            cu_seqlens,
+            cu_seqlens,
+            None,
+            max_seqlen,
+            max_seqlen,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query,
+            key,
+            value,
+            attn_output,
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 else:
     attn_output = flash_attn_2_cuda.varlen_fwd(
         query,
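The `.contiguous()` calls on the XPU path matter because query/key/value are often views into a larger fused QKV projection rather than dense tensors (an assumption about these models, but a common layout). A quick standalone check of that behavior, in plain PyTorch with no IPEX required:

```python
import torch

qkv = torch.randn(16, 3 * 64)    # output of a fused QKV projection
q, k, v = qkv.split(64, dim=-1)  # three strided views into one storage
print(q.is_contiguous())         # False: row stride is 192, not 64

qc = q.contiguous()              # materializes a dense copy
print(qc.is_contiguous())        # True
print(qc.data_ptr() != q.data_ptr())  # True: new allocation, data copied
```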
@@ -59,7 +59,7 @@ def concatenate(cls, batches):
 @tracer.start_as_current_span("filter")
 def filter(self, request_ids: List[int]):
     assert self.image_indices is not None
-    batch = super().filter(request_ids)
+    batch = super(VlmCausalLMBatch, self).filter(request_ids)
     assert self.image_indices is not None
     indices = []
     for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ def filter(self, request_ids: List[int]):
         ]
     else:
         batch.cross_attention_states = None
+        batch.pixel_values = None
     return batch

 @classmethod
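Two notes on this last hunk. The added `batch.pixel_values = None` presumably drops image tensors that have already been consumed, so a filtered batch does not carry them forward. As for `super(VlmCausalLMBatch, self)`: the zero-argument form relies on a hidden `__class__` cell that only exists for functions defined inside a class body, so it breaks if the method is defined elsewhere and attached later, while the explicit two-argument form names its MRO anchor and always works. A minimal, self-contained illustration with hypothetical class names (whether this is the exact failure mode the PR fixes is an assumption):

```python
class Base:
    def filter(self, ids):
        return [i for i in ids if i >= 0]


def filter_implicit(self, ids):
    # Zero-argument super() needs the __class__ cell created only for
    # functions defined inside a class body; calling this raises
    # "RuntimeError: super(): __class__ cell not found".
    return super().filter(ids)


def filter_explicit(self, ids):
    # The two-argument form names its MRO anchor explicitly, so it works
    # even when the function is attached to the class after the fact.
    return super(Sub, self).filter(ids)


class Sub(Base):
    pass


Sub.filter = filter_explicit
print(Sub().filter([3, -1, 2]))  # -> [3, 2]
```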