Skip to content

Commit 7fe29ef

Browse files
depr: Disable pickle deserialization by default for security (#1245)
Adds an `allow_pickle=False` parameter to `deserialize_values()` and to all calling functions (`job.result()`, `load_job_result()`, `load_job_checkpoint()`). When data is in `PICKLED_V4` format and `allow_pickle` is not explicitly set to `True`, a `RuntimeError` is raised with clear instructions on how to opt in. Co-authored-by: Cody Wang <speller26@gmail.com>
1 parent fa50b76 commit 7fe29ef

10 files changed

Lines changed: 177 additions & 27 deletions

File tree

src/braket/aws/aws_quantum_job.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ def result(
532532
self,
533533
poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT,
534534
poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL,
535+
allow_pickle: bool = False,
535536
) -> dict[str, Any]:
536537
"""Retrieves the hybrid job result persisted using the `save_job_result` function.
537538
@@ -540,12 +541,16 @@ def result(
540541
Default: 10 days.
541542
poll_interval_seconds (float): The polling interval, in seconds, for `result()`.
542543
Default: 5 seconds.
544+
allow_pickle (bool): Whether to allow deserialization of pickled data. Pickle
545+
deserialization can execute arbitrary code and is unsafe on untrusted data.
546+
Default: False.
543547
544548
Returns:
545549
dict[str, Any]: Dict specifying the job results.
546550
547551
Raises:
548-
RuntimeError: if hybrid job is in a FAILED or CANCELLED state.
552+
RuntimeError: if hybrid job is in a FAILED or CANCELLED state, or if data is in
553+
PICKLED_V4 format and allow_pickle is False.
549554
TimeoutError: if hybrid job execution exceeds the polling timeout period.
550555
"""
551556
with tempfile.TemporaryDirectory() as temp_dir:
@@ -557,11 +562,16 @@ def result(
557562
if e.response["Error"]["Code"] == "404":
558563
return {}
559564
raise
560-
return AwsQuantumJob._read_and_deserialize_results(temp_dir, job_name)
565+
return AwsQuantumJob._read_and_deserialize_results(temp_dir, job_name, allow_pickle)
561566

562567
@staticmethod
563-
def _read_and_deserialize_results(temp_dir: str, job_name: str) -> dict[str, Any]:
564-
return load_job_result(Path(temp_dir, job_name, AwsQuantumJob.RESULTS_FILENAME))
568+
def _read_and_deserialize_results(
569+
temp_dir: str, job_name: str, allow_pickle: bool = False
570+
) -> dict[str, Any]:
571+
return load_job_result(
572+
Path(temp_dir, job_name, AwsQuantumJob.RESULTS_FILENAME),
573+
allow_pickle=allow_pickle,
574+
)
565575

566576
def download_result(
567577
self,

src/braket/jobs/_entry_point_template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def {function_name}():
1919
finally:
2020
clean_links(links)
2121
if result is not None:
22-
save_job_result(result, data_format=PersistedJobDataFormat.PICKLED_V4)
22+
save_job_result(result, data_format=PersistedJobDataFormat.{data_format})
2323
return result
2424
"""
2525

src/braket/jobs/data_persistence.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def save_job_checkpoint(
6666

6767

6868
def load_job_checkpoint(
69-
job_name: str | None = None, checkpoint_file_suffix: str = ""
69+
job_name: str | None = None,
70+
checkpoint_file_suffix: str = "",
71+
allow_pickle: bool = False,
7072
) -> dict[str, Any]:
7173
"""Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint
7274
file that ends with the `checkpoint_file_suffix`. The `job_name` can refer to any job whose
@@ -87,6 +89,10 @@ def load_job_checkpoint(
8789
`f"{job_name}(_{checkpoint_file_suffix}).json"` is used to locate the
8890
checkpoint file. Default: ""
8991
92+
allow_pickle (bool): Whether to allow deserialization of pickled data. Pickle
93+
deserialization can execute arbitrary code and is unsafe on untrusted data.
94+
Default: False.
95+
9096
Returns:
9197
dict[str, Any]: Dict that contains the checkpoint data persisted in the checkpoint file.
9298
@@ -95,6 +101,7 @@ def load_job_checkpoint(
95101
in the directory specified by the container environment variable `CHECKPOINT_DIR`.
96102
ValueError: If the data stored in the checkpoint file can't be deserialized (possibly due to
97103
corruption).
104+
RuntimeError: If data is in PICKLED_V4 format and allow_pickle is False.
98105
"""
99106
job_name = job_name or get_job_name()
100107
checkpoint_directory = get_checkpoint_dir()
@@ -105,7 +112,9 @@ def load_job_checkpoint(
105112
)
106113
with open(checkpoint_file_path, encoding="utf-8") as f:
107114
persisted_data = PersistedJobData.parse_raw(f.read())
108-
return deserialize_values(persisted_data.dataDictionary, persisted_data.dataFormat)
115+
return deserialize_values(
116+
persisted_data.dataDictionary, persisted_data.dataFormat, allow_pickle
117+
)
109118

110119

111120
def _load_persisted_data(filename: str | Path | None = None) -> PersistedJobData:
@@ -120,19 +129,29 @@ def _load_persisted_data(filename: str | Path | None = None) -> PersistedJobData
120129
)
121130

122131

123-
def load_job_result(filename: str | Path | None = None) -> dict[str, Any]:
132+
def load_job_result(
133+
filename: str | Path | None = None, allow_pickle: bool = False
134+
) -> dict[str, Any]:
124135
"""Loads job result of currently running job.
125136
126137
Args:
127138
filename (str | Path | None): Location of job results. Default `results.json` in job
128139
results directory in a job instance or in working directory locally. This file
129140
must be in the format used by `save_job_result`.
141+
allow_pickle (bool): Whether to allow deserialization of pickled data. Pickle
142+
deserialization can execute arbitrary code and is unsafe on untrusted data.
143+
Default: False.
130144
131145
Returns:
132146
dict[str, Any]: Job result data of current job
147+
148+
Raises:
149+
RuntimeError: If data is in PICKLED_V4 format and allow_pickle is False.
133150
"""
134151
persisted_data = _load_persisted_data(filename)
135-
return deserialize_values(persisted_data.dataDictionary, persisted_data.dataFormat)
152+
return deserialize_values(
153+
persisted_data.dataDictionary, persisted_data.dataFormat, allow_pickle
154+
)
136155

137156

138157
def save_job_result(
@@ -180,6 +199,7 @@ def save_job_result(
180199
current_results = deserialize_values(
181200
current_persisted_data.dataDictionary,
182201
current_persisted_data.dataFormat,
202+
allow_pickle=True, # safe: data was written by this SDK inside the job container
183203
)
184204
updated_results = current_results | result_data
185205

src/braket/jobs/hybrid_job.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from typing import Any
3131

3232
import cloudpickle
33+
from braket.jobs_data import PersistedJobDataFormat
3334

3435
from braket.aws.aws_session import AwsSession
3536
from braket.jobs._entry_point_template import run_entry_point, symlink_input_data
@@ -72,6 +73,7 @@ def hybrid_job(
7273
logger: Logger = getLogger(__name__),
7374
quiet: bool | None = None,
7475
reservation_arn: str | None = None,
76+
data_format: PersistedJobDataFormat | str = PersistedJobDataFormat.PLAINTEXT,
7577
) -> Callable:
7678
"""Defines a hybrid job by decorating the entry point function. The job will be created
7779
when the decorated function is called.
@@ -207,7 +209,7 @@ def job_wrapper(*args: Any, **kwargs: Any) -> Callable:
207209
) as entry_point_file:
208210
template = "\n".join([
209211
_process_input_data(input_data),
210-
_serialize_entry_point(entry_point, args, kwargs),
212+
_serialize_entry_point(entry_point, args, kwargs, data_format),
211213
])
212214
entry_point_file.write(template)
213215

@@ -417,8 +419,32 @@ def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001
417419
cloudpickle.unregister_pickle_by_value(module)
418420

419421

420-
def _serialize_entry_point(entry_point: Callable, args: tuple, kwargs: dict) -> str:
422+
_DATA_FORMAT_ALIASES = {
423+
"pickle": PersistedJobDataFormat.PICKLED_V4,
424+
"plaintext": PersistedJobDataFormat.PLAINTEXT,
425+
}
426+
427+
428+
def _resolve_data_format(data_format: PersistedJobDataFormat | str) -> PersistedJobDataFormat:
429+
if isinstance(data_format, PersistedJobDataFormat):
430+
return data_format
431+
resolved = _DATA_FORMAT_ALIASES.get(data_format.lower())
432+
if resolved is None:
433+
raise ValueError(
434+
f"Unknown data_format '{data_format}'. "
435+
f"Use {list(_DATA_FORMAT_ALIASES.keys())} or a PersistedJobDataFormat enum."
436+
)
437+
return resolved
438+
439+
440+
def _serialize_entry_point(
441+
entry_point: Callable,
442+
args: tuple,
443+
kwargs: dict,
444+
data_format: PersistedJobDataFormat | str = PersistedJobDataFormat.PLAINTEXT,
445+
) -> str:
421446
"""Create an entry point from a function"""
447+
data_format = _resolve_data_format(data_format)
422448
wrapped_entry_point = functools.partial(entry_point, *args, **kwargs)
423449

424450
try:
@@ -434,6 +460,7 @@ def _serialize_entry_point(entry_point: Callable, args: tuple, kwargs: dict) ->
434460
return run_entry_point.format(
435461
serialized=serialized,
436462
function_name=entry_point.__name__,
463+
data_format=data_format.name,
437464
)
438465

439466

src/braket/jobs/local/local_job.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def result(
274274
self,
275275
poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT,
276276
poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL,
277+
allow_pickle: bool = False,
277278
) -> dict[str, Any]:
278279
"""Retrieves the `LocalQuantumJob` result persisted using `save_job_result` function.
279280
@@ -282,17 +283,23 @@ def result(
282283
Default: 10 days.
283284
poll_interval_seconds (float): The polling interval, in seconds, for `result()`.
284285
Default: 5 seconds.
286+
allow_pickle (bool): Whether to allow deserialization of pickled data. Pickle
287+
deserialization can execute arbitrary code and is unsafe on untrusted data.
288+
Default: False.
285289
286290
Raises:
287291
ValueError: The local job directory does not exist.
292+
RuntimeError: If data is in PICKLED_V4 format and allow_pickle is False.
288293
289294
Returns:
290295
dict[str, Any]: Dict specifying the hybrid job results.
291296
"""
292297
try:
293298
with open(os.path.join(self.name, "results.json"), encoding="utf-8") as f:
294299
persisted_data = PersistedJobData.parse_raw(f.read())
295-
return deserialize_values(persisted_data.dataDictionary, persisted_data.dataFormat)
300+
return deserialize_values(
301+
persisted_data.dataDictionary, persisted_data.dataFormat, allow_pickle
302+
)
296303
except FileNotFoundError as e:
297304
raise ValueError(
298305
f"Unable to find results in the local job directory {self.name}."

src/braket/jobs/quantum_job.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def result(
155155
self,
156156
poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT,
157157
poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
158+
allow_pickle: bool = False,
158159
) -> dict[str, Any]:
159160
"""Retrieves the hybrid job result persisted using save_job_result() function.
160161
@@ -165,12 +166,16 @@ def result(
165166
poll_interval_seconds (float): The polling interval, in seconds, for `result()`.
166167
Default: 5 seconds.
167168
169+
allow_pickle (bool): Whether to allow deserialization of pickled data. Pickle
170+
deserialization can execute arbitrary code and is unsafe on untrusted data.
171+
Default: False.
168172
169173
Returns:
170174
dict[str, Any]: Dict specifying the hybrid job results.
171175
172176
Raises:
173-
RuntimeError: if hybrid job is in a FAILED or CANCELLED state.
177+
RuntimeError: if hybrid job is in a FAILED or CANCELLED state, or if data is in
178+
PICKLED_V4 format and allow_pickle is False.
174179
TimeoutError: if hybrid job execution exceeds the polling timeout period.
175180
"""
176181

src/braket/jobs/serialization.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,39 @@ def serialize_values(
4545

4646

4747
def deserialize_values(
48-
data_dictionary: dict[str, Any], data_format: PersistedJobDataFormat
48+
data_dictionary: dict[str, Any],
49+
data_format: PersistedJobDataFormat,
50+
allow_pickle: bool = False,
4951
) -> dict[str, Any]:
5052
"""Deserializes the `data_dictionary` values from the format specified by `data_format`.
5153
5254
Args:
5355
data_dictionary (dict[str, Any]): Dict whose values are to be deserialized.
5456
data_format (PersistedJobDataFormat): The data format that the `data_dictionary` values
5557
are currently serialized with.
58+
allow_pickle (bool): Whether to allow deserialization of pickled data. Pickle
59+
deserialization can execute arbitrary code and is unsafe on untrusted data.
60+
Default: False.
5661
5762
Returns:
5863
dict[str, Any]: Dict with same keys as `data_dictionary` and values deserialized from
5964
the specified `data_format` to plaintext.
65+
66+
Raises:
67+
RuntimeError: If data format is PICKLED_V4 and allow_pickle is False.
6068
"""
61-
return (
62-
{k: pickle.loads(codecs.decode(v.encode(), "base64")) for k, v in data_dictionary.items()} # noqa: S301
63-
if data_format == PersistedJobDataFormat.PICKLED_V4
64-
else data_dictionary
65-
)
69+
if data_format == PersistedJobDataFormat.PICKLED_V4:
70+
if not allow_pickle:
71+
raise RuntimeError(
72+
"Data is in PICKLED_V4 format, but pickle deserialization is disabled by "
73+
"default due to security concerns. Pickle deserialization can execute arbitrary "
74+
"code and is unsafe on untrusted data. To enable pickle deserialization, pass "
75+
"allow_pickle=True to the calling function (e.g. job.result(allow_pickle=True), "
76+
"load_job_result(allow_pickle=True), or load_job_checkpoint(allow_pickle=True)). "
77+
"Only do this if you trust the source of the data."
78+
)
79+
return {
80+
k: pickle.loads(codecs.decode(v.encode(), "base64")) # noqa: S301
81+
for k, v in data_dictionary.items()
82+
}
83+
return data_dictionary

test/unit_tests/braket/jobs/test_data_persistence.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,11 @@ def test_load_job_checkpoint(
142142
with open(file_path, "w") as f:
143143
f.write(saved_data)
144144

145+
allow_pickle = data_format == PersistedJobDataFormat.PICKLED_V4
145146
with patch.dict(
146147
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
147148
):
148-
loaded_data = load_job_checkpoint(job_name, file_suffix)
149+
loaded_data = load_job_checkpoint(job_name, file_suffix, allow_pickle=allow_pickle)
149150
assert loaded_data == expected_checkpoint_data
150151

151152

@@ -188,7 +189,7 @@ def test_load_job_checkpoint_raises_error_corrupted_data():
188189
with patch.dict(
189190
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
190191
):
191-
load_job_checkpoint(job_name, file_suffix)
192+
load_job_checkpoint(job_name, file_suffix, allow_pickle=True)
192193

193194

194195
@dataclass
@@ -211,7 +212,7 @@ def test_save_and_load_job_checkpoint():
211212
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
212213
):
213214
save_job_checkpoint(data, data_format=PersistedJobDataFormat.PICKLED_V4)
214-
retrieved = load_job_checkpoint(job_name)
215+
retrieved = load_job_checkpoint(job_name, allow_pickle=True)
215216
assert retrieved == data
216217

217218

@@ -308,7 +309,10 @@ def test_update_result_data(
308309
save_job_result(first_result_data, first_data_format)
309310
save_job_result(second_result_data, second_data_format)
310311

311-
assert load_job_result() == expected_result_data
312+
allow_pickle = first_data_format == PersistedJobDataFormat.PICKLED_V4 or (
313+
second_data_format == PersistedJobDataFormat.PICKLED_V4
314+
)
315+
assert load_job_result(allow_pickle=allow_pickle) == expected_result_data
312316

313317

314318
def test_update_pickled_results_as_plaintext_error():
@@ -322,3 +326,22 @@ def test_update_pickled_results_as_plaintext_error():
322326
)
323327
with pytest.raises(TypeError, match=cannot_convert_pickled_to_plaintext):
324328
save_job_result("hello", PersistedJobDataFormat.PLAINTEXT)
329+
330+
331+
def test_load_job_result_raises_without_allow_pickle():
332+
with tempfile.TemporaryDirectory() as tmp_dir:
333+
with patch.dict(os.environ, {"AMZN_BRAKET_JOB_RESULTS_DIR": tmp_dir}):
334+
save_job_result({"key": "value"}, PersistedJobDataFormat.PICKLED_V4)
335+
with pytest.raises(RuntimeError, match="pickle deserialization is disabled by default"):
336+
load_job_result()
337+
338+
339+
def test_load_job_checkpoint_raises_without_allow_pickle():
340+
with tempfile.TemporaryDirectory() as tmp_dir:
341+
job_name = "test_job"
342+
with patch.dict(
343+
os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name}
344+
):
345+
save_job_checkpoint({"key": "value"}, data_format=PersistedJobDataFormat.PICKLED_V4)
346+
with pytest.raises(RuntimeError, match="pickle deserialization is disabled by default"):
347+
load_job_checkpoint(job_name)

0 commit comments

Comments
 (0)