Skip to content

Commit b8d92a2

Browse files
Wauplinclaudehanouticelina
authored
[Fix] Validate shard filenames in sharded checkpoint index files (#4033)
* Validate shard filenames in sharded checkpoint index files Reject shard references with path traversal or mismatched extensions to prevent a crafted safetensors index from loading pickle payloads. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * remove unrelated tests * Update src/huggingface_hub/serialization/_torch.py Co-authored-by: célina <hanouticelina@gmail.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: célina <hanouticelina@gmail.com>
1 parent 695e73a commit b8d92a2

2 files changed

Lines changed: 72 additions & 4 deletions

File tree

src/huggingface_hub/serialization/_torch.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,14 +510,32 @@ def _load_sharded_checkpoint(
510510
with open(index_file, encoding="utf-8") as f:
511511
index = json.load(f)
512512

513-
# 2. Validate keys if in strict mode
513+
# 2. Validate shard filenames from the index
514+
# This prevents path traversal attacks and extension confusion attacks
515+
# (e.g. a safetensors index referencing .bin pickle files)
516+
expected_extension = Path(filename_pattern.format(suffix="")).suffix # e.g. ".safetensors"
517+
shard_files = list(set(index["weight_map"].values()))
518+
for shard_file in shard_files:
519+
# Reject path traversal (e.g. "../malicious.bin", absolute paths)
520+
if os.path.isabs(shard_file) or ".." in Path(shard_file).parts:
521+
raise ValueError(
522+
f"Invalid shard filename '{shard_file}' in index file '{index_file}'. "
523+
"Shard filenames must be relative paths without '..' components."
524+
)
525+
# Reject extension mismatch (e.g. .bin shard in a .safetensors index)
526+
if not shard_file.endswith(expected_extension):
527+
raise ValueError(
528+
f"Invalid shard filename '{shard_file}' in index file '{index_file}'. "
529+
f"Expected '{expected_extension}' extension to match the index format."
530+
)
531+
532+
# 3. Validate keys if in strict mode
514533
# This is done before loading any shards to fail fast
515534
if strict:
516535
_validate_keys_for_strict_loading(model, index["weight_map"].keys())
517536

518-
# 3. Load each shard using `load_state_dict`
537+
# 4. Load each shard using `load_state_dict`
519538
# Get unique shard files (multiple parameters can be in same shard)
520-
shard_files = list(set(index["weight_map"].values()))
521539
for shard_file in shard_files:
522540
# Load shard into memory
523541
shard_path = os.path.join(save_directory, shard_file)
@@ -531,7 +549,7 @@ def _load_sharded_checkpoint(
531549
# Explicitly remove the state dict from memory
532550
del state_dict
533551

534-
# 4. Return compatibility info
552+
# 5. Return compatibility info
535553
loaded_keys = set(index["weight_map"].keys())
536554
model_keys = set(model.state_dict().keys())
537555
return _IncompatibleKeys(

tests/test_serialization.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,53 @@ def __init__(self):
828828
load_torch_model(model, tmp_path, safe=safe, filename_pattern=filename_pattern)
829829
mock_load.assert_called_once()
830830
assert mock_load.call_args.kwargs["filename_pattern"] == expected_filename_pattern
831+
832+
833+
class TestShardedCheckpointValidation:
834+
"""Regression tests for shard filename validation in sharded checkpoint loading.
835+
836+
See https://github.com/huggingface/hackerone/issues/141 for more details.
837+
838+
Ensures that crafted index files cannot trick the loader into deserializing
839+
unsafe pickle payloads or accessing files outside the checkpoint directory.
840+
"""
841+
842+
def test_safetensors_index_rejects_bin_shard(self, tmp_path):
843+
"""A safetensors index file referencing a .bin shard must be rejected."""
844+
index = {
845+
"metadata": {"total_size": 100},
846+
"weight_map": {
847+
"layer_1": "model-00001-of-00001.bin",
848+
},
849+
}
850+
(tmp_path / "model.safetensors.index.json").write_text(json.dumps(index))
851+
(tmp_path / "model-00001-of-00001.bin").touch()
852+
853+
with pytest.raises(ValueError, match="Invalid shard filename.*Expected '.safetensors' extension"):
854+
load_torch_model(Mock(), tmp_path)
855+
856+
def test_safetensors_index_rejects_path_traversal(self, tmp_path):
857+
"""A shard filename with '..' path traversal must be rejected."""
858+
index = {
859+
"metadata": {"total_size": 100},
860+
"weight_map": {
861+
"layer_1": "../malicious.safetensors",
862+
},
863+
}
864+
(tmp_path / "model.safetensors.index.json").write_text(json.dumps(index))
865+
866+
with pytest.raises(ValueError, match="Invalid shard filename.*without '..' components"):
867+
load_torch_model(Mock(), tmp_path)
868+
869+
def test_safetensors_index_rejects_absolute_path(self, tmp_path):
870+
"""A shard filename with an absolute path must be rejected."""
871+
index = {
872+
"metadata": {"total_size": 100},
873+
"weight_map": {
874+
"layer_1": "/tmp/malicious.safetensors",
875+
},
876+
}
877+
(tmp_path / "model.safetensors.index.json").write_text(json.dumps(index))
878+
879+
with pytest.raises(ValueError, match="Invalid shard filename.*without '..' components"):
880+
load_torch_model(Mock(), tmp_path)

0 commit comments

Comments
 (0)