Commit 9da9b68
fix: disable non-blocking tensor copies to MPS during model loading
When loading model weights with `device_map="mps"`, `load_model_dict_into_meta`
unconditionally passes `non_blocking=True` to `set_module_tensor_to_device`
(accelerate > 1.8.1). With mmap-backed safetensors the source CPU memory can
be released before the asynchronous MPS copy completes, silently corrupting
the destination weights.
The corruption is non-deterministic and dtype-dependent: with float32 the
weights are corrupted but not the biases, and with float16 the biases but
not the weights. It produces extreme values (~1e37), LayerNorm overflow,
and NaN outputs.
Move the `non_blocking` / `clear_cache` assignment after `param_device` is
resolved, and force `non_blocking=False` when the target is MPS.
Fixes #13227
Made-with: Cursor
1 parent c02c17c commit 9da9b68
1 file changed
Lines changed: 11 additions & 4 deletions
| Original file lines | Diff lines | Change |
|---|---|---|
| 253-255 | 253-255 | context (unchanged) |
| 256-259 | | 4 lines removed |
| 260-262 | 256-258 | context (unchanged) |
| 277-279 | 273-275 | context (unchanged) |
| | 276-286 | 11 lines added |
| 280-282 | 287-289 | context (unchanged) |
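The change described in the commit message can be sketched roughly as follows. This is a minimal illustration, not the code from the commit: `is_mps_device` and `build_copy_kwargs` are hypothetical helper names, the `accelerate_supports_non_blocking` flag stands in for the accelerate version check, and `clear_cache=False` is an assumed value. The only behavior taken from the message is that the flags are chosen after `param_device` is known and that `non_blocking` is forced to `False` for MPS.

```python
from typing import Any, Dict


def is_mps_device(param_device: Any) -> bool:
    """Return True when the resolved target device is the MPS backend.

    Hypothetical helper. Accepts strings such as "mps" or "mps:0", and any
    object whose str() has that form.
    """
    return str(param_device).split(":")[0] == "mps"


def build_copy_kwargs(
    param_device: Any, accelerate_supports_non_blocking: bool
) -> Dict[str, Any]:
    """Choose the copy flags only after the target device has been resolved.

    Hypothetical helper. The returned dict is meant to be forwarded as keyword
    arguments to the tensor-placement call.
    """
    kwargs: Dict[str, Any] = {}
    if accelerate_supports_non_blocking:
        # Asynchronous copies stay enabled everywhere except MPS, where the
        # mmap-backed source memory may be released before the copy finishes.
        kwargs["non_blocking"] = not is_mps_device(param_device)
        # Assumed value; the commit message only says this assignment moved.
        kwargs["clear_cache"] = False
    return kwargs


if __name__ == "__main__":
    for device in ("mps", "mps:0", "cuda:0", "cpu"):
        print(device, build_copy_kwargs(device, True))
```

Deciding the flags per parameter, once its device is known, keeps the faster non-blocking path for other backends while making the MPS copy synchronous.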