
Commit 411e388

msaroufim authored and facebook-github-bot committed
Add HF Auth mixin to Stable Diffusion (#1763)
Summary:
Right now stable diffusion and lit-llama are not actually running in CI because they get rate limited by Hugging Face. Since we've now added an auth token as a GitHub secret, we can move stable diffusion out of canary and do things like include it in the blueberries dashboard.

We also added some nice errors so people running torchbench locally know they will need a token to run these models.

Anyway, auth is a mixin, which seems like the right abstraction.

# Some relevant details about the model

Torchbench has a function `get_module()` that is intended to test an `nn.Module` on an actual `torch.Tensor`. Unfortunately, a `StableDiffusionPipeline` is not an `nn.Module`; it's a composition of a tokenizer and 3 separate `nn.Module`s: an encoder, a vae and a unet.

## text_encoder

```python
def get_module(self):
    batch_size = 1
    sequence_length = 10
    vocab_size = 32000

    # Generate random indices within the valid range
    input_tensor = torch.randint(low=0, high=vocab_size, size=(batch_size, sequence_length))

    # Make sure the tensor has the correct data type
    input_tensor = input_tensor.long()

    print(self.pipe.text_encoder(input_tensor))
    return self.pipe.text_encoder, input_tensor
```

Text encoder outputs a `BaseModelOutputWithPooling` which has multiple nn modules https://gist.github.com/msaroufim/51f0038863c5cce4cc3045e4d9f9c399

```
======================================================================
FAIL: test_stable_diffusion_example_cuda (__main__.TestBenchmark)
----------------------------------------------------------------------
components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last):
  File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 482, in _run_block
    exec(  # noqa: P204
  File "<subprocess-worker>", line 35, in <module>
  File "<subprocess-worker>", line 12, in _run_in_worker_f
  File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 26, in __call__
    obj.__post__init__()
  File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 126, in __post__init__
    self.accuracy = check_accuracy(self)
  File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 469, in check_accuracy
    model, example_inputs = maybe_cast(tbmodel, model, example_inputs)
  File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 424, in maybe_cast
    example_inputs = clone_inputs(example_inputs)
  File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 297, in clone_inputs
    assert isinstance(value, torch.Tensor)
AssertionError
```

## vae

```python
def get_module(self):
    print(self.pipe.vae(torch.randn(9,3,9,9)))
```

Same problem for vae https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L27

## unet

```python
def get_module(self):
    # This will only benchmark the unet since that's the biggest layer
    # Stable diffusion is a composition of a text encoder, unet and vae
    encoder_hidden_states = torch.randn(320, 1024)
    sample = torch.randn(4, 4, 4, 32)
    timestep = 5
    inputs_to_pipe = {'timestep': timestep, 'encoder_hidden_states': encoder_hidden_states, 'sample': sample}
    result = self.pipe.unet(**inputs_to_pipe)
    return self.pipe, inputs_to_pipe
```

Unet unfortunately does not have a tensor input.

For VAE and encoder the test failure is particularly helpful

```
(sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_stable_diffusion_example_cuda"
F
======================================================================
FAIL: test_stable_diffusion_example_cuda (__main__.TestBenchmark)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ubuntu/benchmark/test.py", line 75, in example_fn
    assert accuracy == "pass" or accuracy == "eager_1st_run_OOM", f"Expected accuracy pass, get {accuracy}"
AssertionError: Expected accuracy pass, get eager_1st_run_fail

----------------------------------------------------------------------
Ran 1 test in 7.402s

FAILED (failures=1)
```

Pull Request resolved: #1763

Reviewed By: xuzhao9

Differential Revision: D47565523

Pulled By: msaroufim

fbshipit-source-id: c949ce8a31c0a4706658937fc6603a22a4bc3ec6
1 parent 09de70c commit 411e388

8 files changed: +55 -20 lines changed

.github/workflows/pr-a10g.yml

Lines changed: 5 additions & 2 deletions
@@ -10,6 +10,8 @@ env:
   CONDA_ENV: "torchbench"
   DOCKER_IMAGE: "ghcr.io/pytorch/torchbench:latest"
   SETUP_SCRIPT: "/workspace/setup_instance.sh"
+  HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+
 
 jobs:
   pr-test:
@@ -36,8 +38,9 @@ jobs:
       - name: Install and Test TorchBench
         run: |
           container_name=$(docker run \
-            -e CONDA_ENV \
-            -e SETUP_SCRIPT \
+            -e CONDA_ENV="${CONDA_ENV}" \
+            -e SETUP_SCRIPT="${SETUP_SCRIPT}" \
+            -e HUGGING_FACE_HUB_TOKEN="${HUGGING_FACE_HUB_TOKEN}" \
             --tty \
             --detach \
             --shm-size=32gb \

.github/workflows/pr-gha-runner.yml

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ env:
   BASE_CONDA_ENV: "torchbench"
   CONDA_ENV: "pr-ci-a100"
   SETUP_SCRIPT: "/workspace/setup_instance.sh"
+  HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
 
 jobs:
   pr-test:

torchbenchmark/canary_models/stable_diffusion/install.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

torchbenchmark/canary_models/stable_diffusion/__init__.py renamed to torchbenchmark/models/stable_diffusion/__init__.py

Lines changed: 16 additions & 5 deletions
@@ -5,12 +5,13 @@
 """
 from torchbenchmark.tasks import COMPUTER_VISION
 from torchbenchmark.util.model import BenchmarkModel
+from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin
 
 import torch
 from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
 
 
-class Model(BenchmarkModel):
+class Model(BenchmarkModel, HuggingFaceAuthMixin):
     task = COMPUTER_VISION.GENERATION
 
     DEFAULT_TRAIN_BSIZE = 1
@@ -19,22 +20,32 @@ class Model(BenchmarkModel):
     # Default eval precision on CUDA device is fp16
     DEFAULT_EVAL_CUDA_PRECISION = "fp16"
 
-
     def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
+        HuggingFaceAuthMixin.__init__(self)
         super().__init__(test=test, device=device, jit=jit,
                          batch_size=batch_size, extra_args=extra_args)
-        assert self.dargs.precision == "fp16", f"Stable Diffusion model only supports fp16 precision."
         model_id = "stabilityai/stable-diffusion-2"
         scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
-        self.pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
+        self.pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)
         self.pipe.to(self.device)
         self.example_inputs = "a photo of an astronaut riding a horse on mars"
 
     def enable_fp16_half(self):
         pass
 
+
     def get_module(self):
-        return self.model, self.example_inputs
+        batch_size = 1
+        sequence_length = 10
+        vocab_size = 32000
+
+        # Generate random indices within the valid range
+        input_tensor = torch.randint(low=0, high=vocab_size, size=(batch_size, sequence_length))
+
+        # Make sure the tensor has the correct data type
+        input_tensor = input_tensor.long().to(self.device)
+        return self.pipe.text_encoder, [input_tensor]
+
 
     def train(self):
         raise NotImplementedError("Train test is not implemented for the stable diffusion model.")
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from torchbenchmark.util.framework.diffusers import install_diffusers
+from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin
+import torch
+import os
+import warnings
+MODEL_NAME = "stabilityai/stable-diffusion-2"
+
+def load_model_checkpoint():
+    from diffusers import StableDiffusionPipeline
+    StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, safety_checker=None)
+
+if __name__ == "__main__":
+    if not 'HUGGING_FACE_HUB_TOKEN' in os.environ:
+        warnings.warn("Make sure to set `HUGGINGFACE_HUB_TOKEN` so you can download weights")
+    else:
+        install_diffusers()
+        load_model_checkpoint()
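
Note that the install script above checks for `HUGGING_FACE_HUB_TOKEN` but its warning text names `HUGGINGFACE_HUB_TOKEN`. One way to keep the two from drifting apart is to build the message from the same constant that the check uses. The following is a minimal sketch, not the commit's code; `TOKEN_VAR` and `token_available` are hypothetical names introduced only for illustration.

```python
import os
import warnings

# Hypothetical constant: single source of truth for the variable name.
TOKEN_VAR = "HUGGING_FACE_HUB_TOKEN"


def token_available(environ=None):
    """Return True if the token variable is set; warn and return False otherwise."""
    # Accept a mapping so the check can be exercised without touching os.environ.
    environ = os.environ if environ is None else environ
    if TOKEN_VAR not in environ:
        # The message is built from TOKEN_VAR, so it always names the checked variable.
        warnings.warn(f"Make sure to set `{TOKEN_VAR}` so you can download weights")
        return False
    return True
```

An install script could then guard its download step with `if token_available(): ...`.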
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+devices:
+  NVIDIA A100-SXM4-40GB:
+    eval_batch_size: 32
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+train_benchmark: false
+train_deterministic: false
+not_implemented:
+- device: cpu

torchbenchmark/util/framework/huggingface/model_factory.py

Lines changed: 5 additions & 0 deletions
@@ -159,6 +159,11 @@ def eval(self) -> Tuple[torch.Tensor]:
         else:
             return (out["logits"], )
 
+class HuggingFaceAuthMixin:
+    def __init__(self):
+        if not 'HUGGING_FACE_HUB_TOKEN' in os.environ:
+            raise NotImplementedError("Make sure to set `HUGGING_FACE_HUB_TOKEN` so you can download weights")
+
 
 class HuggingFaceGenerationModel(HuggingFaceModel):
     task = NLP.GENERATION
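
The mixin pattern used here can be sketched in isolation. This is an illustrative stand-in, not torchbench code: `AuthMixin`, `BaseModel` and their signatures are hypothetical, and the sketch assumes a base class whose `__init__` does not chain to `super().__init__()`, which is why the mixin's check is called explicitly, as the diff does with `HuggingFaceAuthMixin.__init__(self)`.

```python
import os


class AuthMixin:
    """Hypothetical stand-in for the auth mixin: fails fast when no token is set."""

    TOKEN_VAR = "HUGGING_FACE_HUB_TOKEN"

    def __init__(self):
        if self.TOKEN_VAR not in os.environ:
            raise NotImplementedError(
                f"Make sure to set `{self.TOKEN_VAR}` so you can download weights"
            )


class BaseModel:
    """Hypothetical stand-in for a benchmark base class; it only records its arguments."""

    def __init__(self, test, device):
        self.test = test
        self.device = device


class Model(BaseModel, AuthMixin):
    def __init__(self, test, device):
        # Run the token check first, before any expensive setup or downloads.
        # In this sketch BaseModel.__init__ does not call super().__init__(),
        # so the mixin's check would never run without this explicit call.
        AuthMixin.__init__(self)
        super().__init__(test=test, device=device)
```

With the variable unset, `Model("eval", "cuda")` raises `NotImplementedError` before the base class is initialized; with it set, construction proceeds normally.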

torchbenchmark/util/metadata_utils.py

Lines changed: 1 addition & 1 deletion
@@ -20,4 +20,4 @@ def skip_by_metadata(test: str, device:str, jit: bool, extra_args: List[str], me
             match_item("extra_args", extra_args, skip_item)
         if match:
             return True
-    return False
+    return False
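
For readers unfamiliar with metadata-driven skipping, the sketch below shows how a `not_implemented` entry such as `- device: cpu` from the metadata added in this commit could be matched against a requested device. It is a simplified illustration under stated assumptions, not the body of `skip_by_metadata`: the helper name `is_not_implemented` is hypothetical, and only the `device` key is considered.

```python
from typing import Any, Dict, List


def is_not_implemented(device: str, metadata: Dict[str, Any]) -> bool:
    """Return True if any `not_implemented` entry names the given device.

    Hypothetical helper: real skip logic may also match on test, jit or extra args.
    """
    entries: List[Dict[str, Any]] = metadata.get("not_implemented", [])
    return any(entry.get("device") == device for entry in entries)


# Metadata mirroring part of the YAML added in this commit, written as a dict.
metadata = {
    "devices": {"NVIDIA A100-SXM4-40GB": {"eval_batch_size": 32}},
    "not_implemented": [{"device": "cpu"}],
}
```

Under these assumptions, `is_not_implemented("cpu", metadata)` is true and `is_not_implemented("cuda", metadata)` is false, so a CPU run would be skipped while a CUDA run would proceed.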
