@@ -52,6 +52,7 @@ def __init__(
5252 adm_in_channels = None ,
5353 transformer_depth_middle = None ,
5454 transformer_depth_output = None ,
55+ attn_precision = None ,
5556 device = None ,
5657 operations = comfy .ops .disable_weight_init ,
5758 ** kwargs ,
@@ -202,7 +203,7 @@ def __init__(
202203 SpatialTransformer (
203204 ch , num_heads , dim_head , depth = num_transformers , context_dim = context_dim ,
204205 disable_self_attn = disabled_sa , use_linear = use_linear_in_transformer ,
205- use_checkpoint = use_checkpoint , dtype = self .dtype , device = device , operations = operations
206+ use_checkpoint = use_checkpoint , attn_precision = attn_precision , dtype = self .dtype , device = device , operations = operations
206207 )
207208 )
208209 self .input_blocks .append (TimestepEmbedSequential (* layers ))
@@ -262,7 +263,7 @@ def __init__(
262263 mid_block += [SpatialTransformer ( # always uses a self-attn
263264 ch , num_heads , dim_head , depth = transformer_depth_middle , context_dim = context_dim ,
264265 disable_self_attn = disable_middle_self_attn , use_linear = use_linear_in_transformer ,
265- use_checkpoint = use_checkpoint , dtype = self .dtype , device = device , operations = operations
266+ use_checkpoint = use_checkpoint , attn_precision = attn_precision , dtype = self .dtype , device = device , operations = operations
266267 ),
267268 ResBlock (
268269 ch ,
0 commit comments