15 changes: 14 additions & 1 deletion src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -29,6 +29,7 @@
_dequantize_gguf_and_restore_linear,
_quant_shape_from_byte_shape,
_replace_with_gguf_linear,
dequantize_gguf_tensor,
)


@@ -116,6 +117,17 @@ def create_quantized_param(
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

# If the GGUFParameter should not be quantized (for example, it is a submodule of any excluded module),
# dequantize it and set the (dequantized) parameter to the proper dtype.
if isinstance(param_value, GGUFParameter) and any(
m in param_name.split(".") for m in self.modules_to_not_convert
):
keep_in_fp32 = getattr(self, "keep_in_fp32_modules", [])
target_dtype = (
torch.float32 if any(m in param_name.split(".") for m in keep_in_fp32) else self.compute_dtype
)
param_value = dequantize_gguf_tensor(param_value).to(target_dtype)
Comment on lines +120 to +129
Member
I am a bit confused. If a param is already GGUFParameter type, then I'd assume that it's already quantized. In that case, how come dequantize -> type upcasting is the right sequence of ops?

What am I missing?

Collaborator Author
@dg845 May 9, 2026
The idea is that the GGUF checkpoint might specify a quantization for a parameter that we do not want to be quantized, as expressed through either _keep_in_fp32_modules on the model (ModelMixin) or modules_to_not_convert on GGUFQuantizationConfig.

When we load the GGUF state dict, these parameters will be placed into a GGUFParameter, and this happens before we load the weights into the model (e.g. in FromOriginalModelMixin.from_single_file). To respect modules_to_not_convert, we need to convert these back into normal (unquantized) parameters, which we do here at load time via dequantize_gguf_tensor. We then need to cast the parameter to the right compute dtype, which is torch.float32 for keep_in_fp32_modules and compute_dtype otherwise.

Currently, GGUFQuantizationConfig doesn't expose a modules_to_not_convert argument, but keep_in_fp32_modules are included in modules_to_not_convert:

self.modules_to_not_convert.extend(keep_in_fp32_modules)

So, for now, this change only affects whatever _keep_in_fp32_modules are specified on the model.
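
To make that concrete, here is a minimal standalone sketch of the per-parameter dtype decision described above (the module names and the helper resolve_target_dtype are hypothetical, not part of the diff):

import torch

def resolve_target_dtype(param_name, modules_to_not_convert, keep_in_fp32_modules, compute_dtype):
    # Mirrors the check added in create_quantized_param: a parameter whose dotted name
    # contains an excluded module is dequantized at load time; everything else stays
    # a quantized GGUFParameter.
    parts = param_name.split(".")
    if not any(m in parts for m in modules_to_not_convert):
        return None  # left quantized
    if any(m in parts for m in keep_in_fp32_modules):
        return torch.float32
    return compute_dtype

# keep_in_fp32_modules is folded into modules_to_not_convert by the quantizer.
keep_in_fp32 = ["norm_out"]
not_convert = ["proj_out"] + keep_in_fp32

print(resolve_target_dtype("norm_out.linear.weight", not_convert, keep_in_fp32, torch.bfloat16))    # torch.float32
print(resolve_target_dtype("proj_out.weight", not_convert, keep_in_fp32, torch.bfloat16))           # torch.bfloat16
print(resolve_target_dtype("blocks.0.attn.to_q.weight", not_convert, keep_in_fp32, torch.bfloat16)) # None (stays quantized)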


if tensor_name in module._parameters:
module._parameters[tensor_name] = param_value.to(target_device)
if tensor_name in module._buffers:
@@ -130,7 +142,8 @@ def _process_model_before_weight_loading(
):
state_dict = kwargs.get("state_dict", None)

self.modules_to_not_convert.extend(keep_in_fp32_modules)
self.keep_in_fp32_modules = [module for module in keep_in_fp32_modules if module is not None]
self.modules_to_not_convert.extend(self.keep_in_fp32_modules)
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

_replace_with_gguf_linear(
4 changes: 3 additions & 1 deletion src/diffusers/quantizers/gguf/utils.py
@@ -80,7 +80,7 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
weight = dequantize_gguf_tensor(qweight)
return x @ weight.T
return x @ weight.to(x.dtype).T
Member
Would it break torch.compile compatibility for models that don't define modules_to_not_convert / keep_in_fp32_modules?

Collaborator Author
I'm not sure how it will interact with torch.compile, but this change mirrors the implementation used for quantized weight types (qweight_type in DEQUANT_TYPES):

weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.to(x.dtype).T

So I think it should be fine? (I think this change isn't specific to modules_to_not_convert, as the GGUF checkpoint could store weights in e.g. BF16 even if modules_to_not_convert is empty, which would then go through this code path.)
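
As a minimal illustration of why the explicit cast matters here (a sketch, assuming the usual PyTorch behavior that matmul does not promote across different float dtypes; not part of the diff):

import torch

x = torch.randn(2, 4, dtype=torch.bfloat16)      # activations in compute_dtype
weight = torch.randn(8, 4, dtype=torch.float16)  # e.g. an F16 tensor stored unquantized in the GGUF file

try:
    y = x @ weight.T                             # mixed-dtype matmul raises a RuntimeError
except RuntimeError as err:
    print(err)

y = x @ weight.to(x.dtype).T                     # same pattern as the DEQUANT_TYPES code path
print(y.dtype)                                   # torch.bfloat16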


# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
# contiguous batching and inefficient with diffusers' batching,
@@ -134,6 +134,8 @@ def _should_convert_to_gguf(state_dict, prefix):
return

for name, module in model.named_children():
if name in modules_to_not_convert:
continue
module_prefix = prefix + name + "."
_replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)
