Commit 9da9b68

Prakhar Agarwal authored and committed
fix: disable non-blocking tensor copies to MPS during model loading
When loading model weights with `device_map="mps"`, `load_model_dict_into_meta` unconditionally passes `non_blocking=True` to `set_module_tensor_to_device` (accelerate > 1.8.1). With mmap-backed safetensors the source CPU memory can be released before the asynchronous MPS copy completes, silently corrupting the destination weights. The corruption is non-deterministic and dtype-dependent (float32 corrupts weights but not biases; float16 corrupts biases but not weights), producing extreme values (~1e37), LayerNorm overflow, and NaN outputs. Move the `non_blocking` / `clear_cache` assignment after `param_device` is resolved, and force `non_blocking=False` when the target is MPS. Fixes #13227 Made-with: Cursor
1 parent: c02c17c

1 file changed: 11 additions & 4 deletions

src/diffusers/models/model_loading_utils.py
@@ -253,10 +253,6 @@ def load_model_dict_into_meta(
             param = param.to(dtype)
         set_module_kwargs["dtype"] = dtype
 
-        if is_accelerate_version(">", "1.8.1"):
-            set_module_kwargs["non_blocking"] = True
-            set_module_kwargs["clear_cache"] = False
-
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -277,6 +273,17 @@ def load_model_dict_into_meta(
 
         param_device = _determine_param_device(param_name, device_map)
 
+        if is_accelerate_version(">", "1.8.1"):
+            # MPS does not support truly asynchronous non-blocking transfers from CPU.
+            # When non_blocking=True the source tensor may be freed or recycled (especially
+            # with mmap-backed safetensors) before the MPS copy completes, silently corrupting
+            # the destination weights. Force synchronous copies on MPS to avoid this.
+            is_mps_target = str(param_device) == "mps" or (
+                isinstance(param_device, torch.device) and param_device.type == "mps"
+            )
+            set_module_kwargs["non_blocking"] = not is_mps_target
+            set_module_kwargs["clear_cache"] = False
+
         # bnb params are flattened.
         # gguf quants have a different shape based on the type of quantization applied
         if empty_state_dict[param_name].shape != param.shape:
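The device check added in the hunk above can be exercised in isolation. Below is a minimal, hypothetical `safe_non_blocking` helper (not part of the patch) that makes the same decision from the device's string form alone, so it needs no `torch` import; the actual patch additionally accepts `torch.device` objects via an `isinstance` check, though `str(torch.device("mps"))` reduces to `"mps"` as well:

```python
def safe_non_blocking(param_device) -> bool:
    """Return True only when a non-blocking host-to-device copy is safe.

    Mirrors the guard in the commit: MPS targets get synchronous copies,
    because the mmap-backed CPU source may be unmapped before an async
    copy to MPS completes. Accepts plain strings ("mps", "mps:0", "cuda")
    or any object whose str() form names the device.
    """
    # "mps:0" -> "mps"; "cuda" -> "cuda"; etc.
    device_type = str(param_device).split(":")[0]
    return device_type != "mps"
```

With this guard, the kwargs assignment in the patch becomes `set_module_kwargs["non_blocking"] = safe_non_blocking(param_device)`: non-MPS targets keep the fast asynchronous path, while MPS falls back to synchronous copies.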
