Commit eb60fc3

Ubuntu and claude committed
fix: add weights_only=True to torch.load in GPU inference pipeline
Mitigate unsafe deserialization vulnerability (CWE-502) in the GPU inference pipeline. torch.load without weights_only=True allows arbitrary code execution via malicious pickle payloads in checkpoint files.

Affected locations:
- gpu/convert_checkpoint.py:37 (checkpoint conversion utility)
- gpu/generate.py:67,69 (fp16 and int2 checkpoint loading)

The utils/ scripts already applied this parameter correctly; this commit brings the GPU pipeline to the same safety standard.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent 8fd3412 commit eb60fc3
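The commit message's claim about pickle payloads can be illustrated with the standard library alone. The sketch below is not code from this repository: `MaliciousPayload`, `AllowlistUnpickler`, and `restricted_loads` are hypothetical names, and the harmless expression `6 * 7` stands in for whatever an attacker would actually run. It shows that a plain `pickle.loads` calls an attacker-chosen callable during deserialization, and that an unpickler with an allowlisted `find_class` refuses it. `weights_only=True` relies on a restricted unpickler in the same spirit, although its actual allowlist is defined by PyTorch.

```python
import io
import pickle
from collections import OrderedDict


class MaliciousPayload:
    """Hypothetical stand-in for an object hidden in a checkpoint file."""

    def __reduce__(self):
        # pickle stores this callable and its arguments, then calls it at
        # load time. A real payload could name os.system or similar.
        return (eval, ("6 * 7",))


class AllowlistUnpickler(pickle.Unpickler):
    """Resolves only globals that are explicitly allowlisted."""

    ALLOWED = {("collections", "OrderedDict")}

    def find_class(self, module, name):
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data):
    """Deserialize bytes with the allowlisting unpickler."""
    return AllowlistUnpickler(io.BytesIO(data)).load()


payload = pickle.dumps(MaliciousPayload())

# Unrestricted load: the attacker-chosen expression is evaluated, and its
# result is what the caller gets back instead of a checkpoint.
executed_result = pickle.loads(payload)

# Restricted load: the lookup of builtins.eval is rejected before any call.
try:
    restricted_loads(payload)
    blocked = False
except pickle.UnpicklingError:
    blocked = True

# Plain containers of the kind a state dict uses still load normally.
benign = restricted_loads(pickle.dumps(OrderedDict(w=1)))
```

The allowlist here is deliberately tiny; the point is only that rejection happens at global lookup, before any callable from the file is invoked.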

2 files changed

Lines changed: 3 additions & 3 deletions

gpu/convert_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def quant_weight_fp16(weight):
     def convert_int8_to_int2(weight):
         return convert_weight_int8_to_int2(weight)
 
-    merged_result = torch.load(input_path, map_location="cpu", mmap=True)
+    merged_result = torch.load(input_path, map_location="cpu", mmap=True, weights_only=True)
     int2_result = {}
     fp16_result = {}
     zero = torch.zeros(1).to(torch.bfloat16)

gpu/generate.py

Lines changed: 2 additions & 2 deletions
@@ -64,9 +64,9 @@ def build(
         decode_model = fast.Transformer(model_args_decode)
 
         fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
-        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
+        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu", weights_only=True)
         int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
-        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
+        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu", weights_only=True)
         prefill_model.load_state_dict(fp16_checkpoint, strict=True)
         decode_model.load_state_dict(int2_checkpoint, strict=True)
 
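To see the effect of the changed calls in isolation, the following sketch round-trips a small state dict through `torch.load(..., weights_only=True)` and `load_state_dict(..., strict=True)`, then tries the same load on a tampered file. It is an illustration under stated assumptions, not part of the commit: the `torch.nn.Linear` model, the file names, and `MaliciousPayload` are all hypothetical, and the exact error message varies across PyTorch versions.

```python
import os
import pickle
import tempfile

import torch


class MaliciousPayload:
    """Hypothetical object an attacker might embed in a checkpoint."""

    def __reduce__(self):
        # Harmless stand-in for an arbitrary call made at load time.
        return (eval, ("6 * 7",))


with tempfile.TemporaryDirectory() as tmp_dir:
    # A legitimate checkpoint: a state dict containing only tensors.
    model = torch.nn.Linear(4, 2)
    good_path = os.path.join(tmp_dir, "model_state_demo.pt")
    torch.save(model.state_dict(), good_path)

    # Same call shape as the patched lines in gpu/generate.py.
    state = torch.load(good_path, map_location="cpu", weights_only=True)
    fresh = torch.nn.Linear(4, 2)
    fresh.load_state_dict(state, strict=True)
    weights_match = torch.equal(fresh.weight, model.weight)

    # A tampered checkpoint: a non-tensor object hidden inside the dict.
    bad_path = os.path.join(tmp_dir, "tampered_demo.pt")
    torch.save({"weight": MaliciousPayload()}, bad_path)
    try:
        torch.load(bad_path, map_location="cpu", weights_only=True)
        rejected = False
    except pickle.UnpicklingError:
        # The restricted unpickler refuses the unsupported global.
        rejected = True
```

Checkpoints that hold only tensors and plain containers, as the fp16 and int2 state dicts here appear to, should load unchanged with the new flag; files that pickle custom classes would need to be re-exported as plain state dicts.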
