|
212 | 212 | # limitations under the License. |
213 | 213 | from typing import Callable, Optional |
214 | 214 | import torch |
215 | | -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection |
| 215 | +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, \ |
| 216 | + CLIPVisionModelWithProjection, CLIPImageProcessor |
216 | 217 | from accelerate.logging import get_logger |
217 | 218 |
|
218 | 219 | from diffusers.models import AutoencoderKL, UNet2DConditionModel |
@@ -551,22 +552,36 @@ def __init__( |
551 | 552 | tokenizer_2: CLIPTokenizer, |
552 | 553 | unet: UNet2DConditionModel, |
553 | 554 | scheduler: KarrasDiffusionSchedulers, |
| 555 | + image_encoder: CLIPVisionModelWithProjection = None, |
| 556 | + feature_extractor: CLIPImageProcessor = None, |
554 | 557 | force_zeros_for_empty_prompt: bool = True, |
555 | 558 | add_watermarker: Optional[bool] = None, |
556 | 559 | modifier_token: list = [], |
557 | 560 | modifier_token_id: list = [], |
558 | 561 | modifier_token_id_2: list = [] |
559 | 562 | ): |
560 | | - super().__init__(vae=vae, |
561 | | - text_encoder=text_encoder, |
562 | | - text_encoder_2=text_encoder_2, |
563 | | - tokenizer=tokenizer, |
564 | | - tokenizer_2=tokenizer_2, |
565 | | - unet=unet, |
566 | | - scheduler=scheduler, |
| 563 | + super().__init__(vae, |
| 564 | + text_encoder, |
| 565 | + text_encoder_2, |
| 566 | + tokenizer, |
| 567 | + tokenizer_2, |
| 568 | + unet, |
| 569 | + scheduler, |
| 570 | + image_encoder=image_encoder, |
| 571 | + feature_extractor=feature_extractor, |
567 | 572 | force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, |
568 | 573 | add_watermarker=add_watermarker, |
569 | 574 | ) |
| 575 | + # super().__init__(vae, |
| 576 | + # text_encoder, |
| 577 | + # text_encoder_2, |
| 578 | + # tokenizer, |
| 579 | + # tokenizer_2, |
| 580 | + # unet, |
| 581 | + # scheduler, |
| 582 | + # force_zeros_for_empty_prompt, |
| 583 | + # add_watermarker, |
| 584 | + # ) |
570 | 585 |
|
571 | 586 | # change attn class |
572 | 587 | self.modifier_token = modifier_token |
|
0 commit comments