GGUF: torch.compile cannot trace sets #12556
base: main
Conversation
torch.compile fails to trace set membership checks such as

```
s = { ... }
if x in s:
```

but it can trace the same membership check against a list.
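A minimal standalone sketch of the pattern described above; the constant names and values are illustrative, not the actual diffusers GGUF type sets, and whether the set version actually graph-breaks depends on the installed torch version:

```
import torch

# Illustrative constants only -- not the real diffusers GGUF type sets.
SET_TYPES = {0, 1, 2}
LIST_TYPES = [0, 1, 2]

def scale_if_known(x: torch.Tensor, qtype: int, container) -> torch.Tensor:
    # Plain Python membership test on a non-tensor value; Dynamo has to
    # evaluate `qtype in container` while tracing.
    if qtype in container:
        return x * 2.0
    return x

x = torch.ones(4)

# List version: traceable.
compiled_list = torch.compile(lambda t, q: scale_if_known(t, q, LIST_TYPES), fullgraph=True)
print(compiled_list(x, 1))

# Set version: the pattern this PR reports as untraceable (behaviour can
# differ across torch versions).
compiled_set = torch.compile(lambda t, q: scale_if_known(t, q, SET_TYPES), fullgraph=True)
print(compiled_set(x, 1))
```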
The review comment below refers to this hunk of the PR diff:

```
-UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
-STANDARD_QUANT_TYPES = {
+UNQUANTIZED_TYPES = [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]
+STANDARD_QUANT_TYPES = [
```
We use set operations on them, for example here:

```
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
```

Would this still work with lists?
You are right, this would not work with lists; another solution is required.
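As a standalone illustration (not taken from the PR), the snippet below shows why the union above breaks once the constants become lists, and sketches one possible list-friendly alternative; that alternative is only an assumption, not the solution adopted here:

```
# `|` is set union; it is not defined for lists, so the DEQUANT_TYPES
# union above would raise once the constants are converted to lists.
a = {1, 2}
b = {2, 3}
print(a | b)            # {1, 2, 3}

a_list = [1, 2]
b_list = [2, 3]
try:
    a_list | b_list
except TypeError as err:
    print(err)          # unsupported operand type(s) for |: 'list' and 'list'

# One possible alternative (an assumption, not the PR's final solution):
# concatenate and de-duplicate while keeping list semantics.
merged = list(dict.fromkeys(a_list + b_list))
print(merged)           # [1, 2, 3]
```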
Thanks for the PR! Could you provide a testing script?
@dxqb Could you provide an example where compile is failing with GGUF? I'm able to run this snippet without any issues:

```
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

model_id = "black-forest-labs/FLUX.1-dev"
ckpt_path = (
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, fullgraph=True)

pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

images = pipe("A cat holding a sign that says hello").images[0]
images.save("flux-gguf-compile.png")
```
Who can review?
@Isotr0py