Commit a6cfdf7

Prakhar Agarwal authored and committed
fix: disable non-blocking tensor copies to MPS during model loading
When loading model weights with `device_map="mps"`, the `non_blocking=True` parameter in `set_module_tensor_to_device` causes asynchronous CPU-to-MPS copies. With mmap-backed safetensors, the source CPU memory can be released before the MPS copy completes, silently corrupting the destination weights. The corruption is non-deterministic and dtype-dependent (float32 corrupts weights but not biases; float16 corrupts biases but not weights), resulting in extreme values (~1e37), LayerNorm overflow, and NaN outputs.

Force synchronous copies (`non_blocking=False`) when the target device is MPS. Other devices (CUDA, CPU) continue to use non-blocking transfers.

Fixes #13227

Made-with: Cursor
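The device check at the heart of the fix can be sketched in isolation. The following is a minimal, torch-free illustration: `Device` is a hypothetical stand-in for `torch.device`, and `resolve_non_blocking` is a hypothetical helper name, not part of the actual patch; the guard expression itself mirrors the one in the diff.

```python
from dataclasses import dataclass


@dataclass
class Device:
    """Hypothetical stand-in for torch.device, for illustration only."""
    type: str

    def __str__(self) -> str:
        return self.type


def resolve_non_blocking(param_device) -> bool:
    """Return the non_blocking flag to use when copying a tensor to param_device.

    The target may arrive as a plain string ("mps") or as a device object,
    so both forms are checked. MPS targets get synchronous copies (False);
    all other devices keep the non-blocking fast path (True).
    """
    is_mps_target = str(param_device) == "mps" or (
        isinstance(param_device, Device) and param_device.type == "mps"
    )
    return not is_mps_target
```

Usage: `resolve_non_blocking("mps")` and `resolve_non_blocking(Device("mps"))` both return `False`, while `resolve_non_blocking(Device("cuda"))` and `resolve_non_blocking("cpu")` return `True`, matching the commit's intent that only MPS targets are forced to copy synchronously.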
1 parent c02c17c commit a6cfdf7

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

src/diffusers/models/model_loading_utils.py
```diff
@@ -254,7 +254,14 @@ def load_model_dict_into_meta(
             set_module_kwargs["dtype"] = dtype

         if is_accelerate_version(">", "1.8.1"):
-            set_module_kwargs["non_blocking"] = True
+            # MPS does not support truly asynchronous non-blocking transfers from CPU.
+            # When non_blocking=True the source tensor may be freed or recycled (especially
+            # with mmap-backed safetensors) before the MPS copy completes, silently corrupting
+            # the destination weights. Force synchronous copies on MPS to avoid this.
+            is_mps_target = str(param_device) == "mps" or (
+                isinstance(param_device, torch.device) and param_device.type == "mps"
+            )
+            set_module_kwargs["non_blocking"] = not is_mps_target
             set_module_kwargs["clear_cache"] = False

         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
```
