@@ -1,16 +1,15 @@
+from typing import List, Tuple
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
+from invokeai.backend.bria.controlnet_utils import prepare_control_images
+from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField
+
 import torch
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 
-from invokeai.app.invocations.fields import Input, InputField
-from invokeai.app.invocations.model import SubModelType, TransformerField
-from invokeai.app.invocations.primitives import (
-    BaseInvocationOutput,
-    FieldDescriptions,
-    Input,
-    InputField,
-    LatentsField,
-    OutputField,
-)
+from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
+from invokeai.app.invocations.model import SubModelType, TransformerField, VAEField
+from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output
 
@@ -43,6 +42,11 @@ class BriaDenoiseInvocation(BaseInvocation):
         input=Input.Connection,
         title="Transformer",
     )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+        title="VAE",
+    )
     latents: LatentsField = InputField(
         description="Latents to denoise",
         input=Input.Connection,
@@ -68,6 +72,12 @@ class BriaDenoiseInvocation(BaseInvocation):
         input=Input.Connection,
         title="Text IDs",
     )
+    control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
+        description="ControlNet",
+        input=Input.Connection,
+        title="ControlNet",
+        default=None,
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
@@ -83,16 +93,28 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
         with (
             context.models.load(self.transformer.transformer) as transformer,
             context.models.load(scheduler_identifier) as scheduler,
+            context.models.load(self.vae.vae) as vae,
         ):
             assert isinstance(transformer, BriaTransformer2DModel)
             assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
+            assert isinstance(vae, AutoencoderKL)
             dtype = transformer.dtype
             device = transformer.device
             latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds))
             prompt_embeds = torch.cat([neg_embeds, pos_embeds]) if self.guidance_scale > 1 else pos_embeds
 
             sigmas = get_original_sigmas(1000, self.num_steps)
             timesteps, _ = retrieve_timesteps(scheduler, self.num_steps, device, None, sigmas, mu=0.0)
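+            # Control inputs are prepared once, before the denoising loop; the hard-coded
+            # 1024x1024 only sets the size at which control images are encoded to latents.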
+            width, height = 1024, 1024
+            if self.control is not None:
+                control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
+                    context=context,
+                    vae=vae,
+                    width=width,
+                    height=height,
+                    device=device,
+                )
 
             for t in timesteps:
                 # Prepare model input efficiently
@@ -101,11 +123,21 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                 else:
                     latent_model_input = latents
 
-                # Prepare timestep tensor efficiently
-                if isinstance(t, torch.Tensor):
-                    timestep_tensor = t.expand(latent_model_input.shape[0])
-                else:
-                    timestep_tensor = torch.tensor([t] * latent_model_input.shape[0], device=device, dtype=torch.float32)
+                timestep_tensor = t.expand(latent_model_input.shape[0])
+
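+                # When ControlNets are connected, run them at every step to produce per-block
+                # residuals that are passed to the transformer below; otherwise both stay None.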
+                controlnet_block_samples, controlnet_single_block_samples = None, None
+                if self.control is not None:
+                    controlnet_block_samples, controlnet_single_block_samples = control_model(
+                        hidden_states=latents,
+                        controlnet_cond=control_images,  # type: ignore
+                        controlnet_mode=control_modes,  # type: ignore
+                        conditioning_scale=control_scales,  # type: ignore
+                        timestep=timestep_tensor,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        return_dict=False,
+                    )
 
                 noise_pred = transformer(
                     latent_model_input,
@@ -115,6 +147,8 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                     txt_ids=text_ids,
                     guidance=None,
                     return_dict=False,
+                    controlnet_block_samples=controlnet_block_samples,
+                    controlnet_single_block_samples=controlnet_single_block_samples,
                 )[0]
 
                 if self.guidance_scale > 1:
@@ -131,3 +165,35 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
         saved_input_latents_tensor = context.tensors.save(latents)
         latents_output = LatentsField(latents_name=saved_input_latents_tensor)
         return BriaDenoiseInvocationOutput(latents=latents_output)
+
+    def _prepare_multi_control(
+        self,
+        context: InvocationContext,
+        vae: AutoencoderKL,
+        width: int,
+        height: int,
+        device: torch.device,
+    ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
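+        """Load the connected ControlNets and turn their images, modes, and conditioning
+        scales into the tensors expected by the Bria multi-ControlNet."""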
+
+        control = self.control if isinstance(self.control, list) else [self.control]
+        control_images, control_models, control_modes, control_scales = [], [], [], []
+        for controlnet in control:
+            if controlnet is not None:
+                control_models.append(context.models.load(controlnet.model).model)
+                control_images.append(context.images.get_pil(controlnet.image.image_name))
+                control_modes.append(BriaControlModes[controlnet.mode].value)
+                control_scales.append(controlnet.conditioning_scale)
+
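+        # Wrap the loaded ControlNets in a single multi-ControlNet module, then use the VAE
+        # to encode the control images to latents; modes are converted to tensors alongside.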
+        control_model = BriaMultiControlNetModel(control_models).to(device)
+        tensored_control_images, tensored_control_modes = prepare_control_images(
+            vae=vae,
+            control_images=control_images,
+            control_modes=control_modes,
+            width=width,
+            height=height,
+            device=device,
+        )
+        return control_model, tensored_control_images, tensored_control_modes, control_scales