@@ -66,7 +66,9 @@ def save_job_checkpoint(
6666
6767
6868def load_job_checkpoint (
69- job_name : str | None = None , checkpoint_file_suffix : str = ""
69+ job_name : str | None = None ,
70+ checkpoint_file_suffix : str = "" ,
71+ allow_pickle : bool = False ,
7072) -> dict [str , Any ]:
7173 """Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint
7274 file that ends with the `checkpoint_file_suffix`. The `job_name` can refer to any job whose
@@ -87,6 +89,10 @@ def load_job_checkpoint(
8789 `f"{job_name}(_{checkpoint_file_suffix}).json"` is used to locate the
8890 checkpoint file. Default: ""
8991
92+ allow_pickle (bool): Whether to allow deserialization of pickled data. Pickle
93+ deserialization can execute arbitrary code and is unsafe on untrusted data.
94+ Default: False.
95+
9096 Returns:
9197 dict[str, Any]: Dict that contains the checkpoint data persisted in the checkpoint file.
9298
@@ -95,6 +101,7 @@ def load_job_checkpoint(
95101 in the directory specified by the container environment variable `CHECKPOINT_DIR`.
96102 ValueError: If the data stored in the checkpoint file can't be deserialized (possibly due to
97103 corruption).
104+ RuntimeError: If data is in PICKLED_V4 format and allow_pickle is False.
98105 """
99106 job_name = job_name or get_job_name ()
100107 checkpoint_directory = get_checkpoint_dir ()
@@ -105,7 +112,9 @@ def load_job_checkpoint(
105112 )
106113 with open (checkpoint_file_path , encoding = "utf-8" ) as f :
107114 persisted_data = PersistedJobData .parse_raw (f .read ())
108- return deserialize_values (persisted_data .dataDictionary , persisted_data .dataFormat )
115+ return deserialize_values (
116+ persisted_data .dataDictionary , persisted_data .dataFormat , allow_pickle
117+ )
109118
110119
111120def _load_persisted_data (filename : str | Path | None = None ) -> PersistedJobData :
@@ -120,19 +129,29 @@ def _load_persisted_data(filename: str | Path | None = None) -> PersistedJobData
120129 )
121130
122131
123- def load_job_result (filename : str | Path | None = None ) -> dict [str , Any ]:
132+ def load_job_result (
133+ filename : str | Path | None = None , allow_pickle : bool = False
134+ ) -> dict [str , Any ]:
124135 """Loads job result of currently running job.
125136
126137 Args:
127138 filename (str | Path | None): Location of job results. Default `results.json` in job
128139 results directory in a job instance or in working directory locally. This file
129140 must be in the format used by `save_job_result`.
141+ allow_pickle (bool): Whether to allow deserialization of pickled data. Pickle
142+ deserialization can execute arbitrary code and is unsafe on untrusted data.
143+ Default: False.
130144
131145 Returns:
132146 dict[str, Any]: Job result data of current job
147+
148+ Raises:
149+ RuntimeError: If data is in PICKLED_V4 format and allow_pickle is False.
133150 """
134151 persisted_data = _load_persisted_data (filename )
135- return deserialize_values (persisted_data .dataDictionary , persisted_data .dataFormat )
152+ return deserialize_values (
153+ persisted_data .dataDictionary , persisted_data .dataFormat , allow_pickle
154+ )
136155
137156
138157def save_job_result (
@@ -180,6 +199,7 @@ def save_job_result(
180199 current_results = deserialize_values (
181200 current_persisted_data .dataDictionary ,
182201 current_persisted_data .dataFormat ,
202+ allow_pickle = True , # safe: data was written by this SDK inside the job container
183203 )
184204 updated_results = current_results | result_data
185205
0 commit comments