Skip to content

Commit fa188f6

Browse files
Fix serialization tests for torch>=2.6.0
1 parent 431819d commit fa188f6

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

tests/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def torch_save_to_buffer(obj):
2222

2323
def torch_load_from_buffer(buffer):
2424
buffer.seek(0)
25-
obj = torch.load(buffer)
25+
obj = torch.load(buffer, weights_only=False)
2626
buffer.seek(0)
2727
return obj
2828

tests/test_linear8bitlt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_linear_serialization(
118118
if not has_fp16_weights:
119119
assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
120120

121-
new_state_dict = torch.load(state_path_8bit)
121+
new_state_dict = torch.load(state_path_8bit, weights_only=False)
122122

123123
new_linear_custom = Linear8bitLt(
124124
linear.in_features,

0 commit comments

Comments
 (0)