System Info
- transformers version: 5.3.0
- Platform: Linux-6.17.0-19-generic-x86_64-with-glibc2.39
- Python version: 3.12.4
- Huggingface_hub version: 1.7.2
- Safetensors version: 0.4.5
- Accelerate version: not installed
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.5.1 (NA)
- Using distributed or parallel set-up in script?: No
When creating a GPTNeoXConfig with a non-default rotary_pct, the value is lost after a save_pretrained / from_pretrained round-trip.
Cause
In transformers/src/transformers/models/gpt_neox/configuration_gpt_neox.py (line 98 at commit 3a3b59c):

self.rope_parameters["partial_rotary_factor"] = kwargs.pop("rotary_pct", 0.25)
save_pretrained writes partial_rotary_factor inside rope_parameters but does not persist rotary_pct as a top-level key. On reload, rotary_pct is absent from kwargs, so this line unconditionally overwrites the correct value with 0.25.
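The asymmetry can be shown with a plain-JSON stand-in for the save/load cycle (no transformers required; the dict below only mimics the relevant part of what save_pretrained writes):

```python
import json

# What gets persisted: partial_rotary_factor lives inside rope_parameters;
# rotary_pct itself is never written as a top-level key.
saved = json.dumps({"rope_parameters": {"partial_rotary_factor": 1.0}})

# On reload, the JSON keys become the kwargs passed to __init__.
kwargs = json.loads(saved)
print("rotary_pct" in kwargs)  # False

# So the offending line falls back to its default and clobbers 1.0:
rope_parameters = kwargs["rope_parameters"]
rope_parameters["partial_rotary_factor"] = kwargs.pop("rotary_pct", 0.25)
print(rope_parameters["partial_rotary_factor"])  # 0.25, not 1.0
```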
Fix
rotary_pct = kwargs.pop("rotary_pct", None)
if rotary_pct is not None:
self.rope_parameters["partial_rotary_factor"] = rotary_pct
else:
self.rope_parameters.setdefault("partial_rotary_factor", 0.25)
Verified locally: after applying this fix, the value survives the round-trip.
Models using the default rotary_pct=0.25 (gpt-neox-20b, Pythia, etc.) are unaffected since the overwrite produces the same value.
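As a quick sanity check, the guarded assignment can be exercised on its own (plain dicts, no transformers; init_rotary is a hypothetical helper wrapping just the patched lines):

```python
def init_rotary(rope_parameters, **kwargs):
    # Patched logic: only overwrite when rotary_pct was explicitly given,
    # otherwise keep whatever was deserialized, falling back to the default.
    rotary_pct = kwargs.pop("rotary_pct", None)
    if rotary_pct is not None:
        rope_parameters["partial_rotary_factor"] = rotary_pct
    else:
        rope_parameters.setdefault("partial_rotary_factor", 0.25)
    return rope_parameters

print(init_rotary({}, rotary_pct=1.0))              # explicit value wins: 1.0
print(init_rotary({"partial_rotary_factor": 1.0}))  # reloaded value kept: 1.0
print(init_rotary({}))                              # default still applies: 0.25
```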
Who can help?
@ArthurZucker @Cyrilvallez
Information
Tasks
Reproduction
from transformers import GPTNeoXConfig
config = GPTNeoXConfig(rotary_pct=1.0)
print(config.rope_parameters["partial_rotary_factor"]) # 1.0
config.save_pretrained("/tmp/test")
config2 = GPTNeoXConfig.from_pretrained("/tmp/test")
print(config2.rope_parameters["partial_rotary_factor"]) # 0.25
Expected behavior
The partial_rotary_factor value should be retained after save_pretrained / from_pretrained.