
torchao >= 0.16.0 quantization not supported #13286

@zzlol63

Describe the bug

The sample code below (taken from https://huggingface.co/blog/lora-fast) no longer works because torchao renamed its quantization APIs. The 0.15.0 release notes list the rename as a breaking change, with the old names still available there behind a deprecation warning:
https://github.com/pytorch/ao/releases/tag/v0.15.0

Before:

from torchao.quantization import (
    float8_dynamic_activation_float8_weight,
    float8_static_activation_float8_weight,
    float8_weight_only,
    fpx_weight_only,
    gemlite_uintx_weight_only,
    int4_dynamic_activation_int4_weight,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
    uintx_weight_only,
)

After:

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8StaticActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    FPXWeightOnlyConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
    UIntXWeightOnlyConfig,
)
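For user code that has to run against torchao releases on both sides of the rename, one option is to look the names up dynamically. The sketch below is an illustration under assumptions, not the diffusers fix: `RENAMED_APIS` is transcribed from the Before/After lists above, and `resolve_quant_api` is a hypothetical helper that prefers the new `*Config` name and falls back to the old one.

```python
import importlib

# Old snake_case name (removed in torchao 0.16.0, per this report) -> new
# config class name, transcribed from the Before/After lists above.
RENAMED_APIS: dict[str, str] = {
    "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
    "float8_static_activation_float8_weight": "Float8StaticActivationFloat8WeightConfig",
    "float8_weight_only": "Float8WeightOnlyConfig",
    "fpx_weight_only": "FPXWeightOnlyConfig",
    "gemlite_uintx_weight_only": "GemliteUIntXWeightOnlyConfig",
    "int4_dynamic_activation_int4_weight": "Int4DynamicActivationInt4WeightConfig",
    "int4_weight_only": "Int4WeightOnlyConfig",
    "int8_dynamic_activation_int4_weight": "Int8DynamicActivationInt4WeightConfig",
    "int8_dynamic_activation_int8_weight": "Int8DynamicActivationInt8WeightConfig",
    "int8_weight_only": "Int8WeightOnlyConfig",
    "uintx_weight_only": "UIntXWeightOnlyConfig",
}


def resolve_quant_api(old_name: str, module=None):
    """Hypothetical helper: look up a torchao quantization API by its old name.

    Prefers the new *Config name and falls back to the old snake_case name,
    so the same call works before and after the rename. `module` defaults to
    torchao.quantization; passing another object allows testing without torchao.
    """
    if module is None:
        module = importlib.import_module("torchao.quantization")
    new_name = RENAMED_APIS[old_name]
    if hasattr(module, new_name):
        return getattr(module, new_name)
    # Older torchao: only the snake_case name exists (AttributeError if neither does).
    return getattr(module, old_name)
```

This only papers over imports in your own code; it does not change how diffusers' `TorchAoConfig` resolves string quant types such as `"float8dq_e4m3_row"`, which is where the traceback below fails.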

In 0.16.0, the old names are removed completely.

Reproduction

from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
import torch

# quantize the Flux transformer with FP8
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_mapping={"transformer": TorchAoConfig("float8dq_e4m3_row")}
    )
).to("cuda")

# use torch.compile()
pipe.transformer.compile(fullgraph=True, mode="max-autotune")

# perform inference
pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 28,
    "max_sequence_length": 512,
}

# first time will be slower, subsequent runs will be faster
image = pipe(**pipe_kwargs).images[0]

Logs

  File "C:\test\test.py", line 485, in main
    quant_mapping={"transformer": TorchAoConfig("float8dq_e4m3_row")}
  File "C:\Users\Home\anaconda3\envs\test\lib\site-packages\diffusers\quantizers\quantization_config.py", line 517, in __init__
    self.post_init()
  File "C:\Users\Home\anaconda3\envs\test\lib\site-packages\diffusers\quantizers\quantization_config.py", line 533, in post_init
    TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
  File "C:\Users\Home\anaconda3\envs\test\lib\site-packages\diffusers\quantizers\quantization_config.py", line 629, in _get_torchao_quant_type_to_method
    from torchao.quantization import (
ImportError: cannot import name 'float8_dynamic_activation_float8_weight' from 'torchao.quantization'

System Info

  • 🤗 Diffusers version: 0.37.0
  • Platform: Windows-10-10.0.26200-SP0
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.10.0+cu130 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.7.1
  • Transformers version: 5.3.0
  • Accelerate version: 1.10.1
  • PEFT version: 0.18.1
  • Bitsandbytes version: 0.49.2
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Labels: bug (Something isn't working)
