Skip to content

Commit a6fa6f8

Browse files
jlapeyre1ucian0
andauthored
Fix qpy serialization of substitution of type ParameterExpression (#13890)
* Fix qpy serialization of substitution of type ParameterExpression When substitution history `ParameterExpression._qpy_replay` is serialized, there was no branch for the case that the substituted value is of type `ParameterExpression`. This commit fixes this oversight. * Fix deserializing qpy written by previous writing fix * Added a branch for reading `ParameterExpression` in _qpy_replay * Added a missing argument (version) in an existing call in code path that was previously untested * Remove a bit of useless code introduced in last commit * run black * Add tests * Revert renaming extra_symbols to extra_expressions Reverting this to minimize changes in order to make the PR safer to backport. * Revert cleaning up logic in _encode_replay_subs This was a good change. But not necessary for the bug fix, which is the main point of this PR. * Add qpy compat test * Bind parameter when checking for equality * Run black after merge in browser --------- Co-authored-by: Luciano Bello <bel@zurich.ibm.com>
1 parent 106864c commit a6fa6f8

3 files changed

Lines changed: 76 additions & 0 deletions

File tree

qiskit/qpy/binary_io/value.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ def _encode_replay_subs(subs, file_obj, version):
142142

143143

144144
def _write_parameter_expression_v13(file_obj, obj, version):
145+
# A symbol is `Parameter` or `ParameterVectorElement`.
146+
# `symbol_map` maps symbols to ParameterExpression (which may be a symbol).
145147
symbol_map = {}
146148
for inst in obj._qpy_replay:
147149
if isinstance(inst, _SUBS):
@@ -234,9 +236,17 @@ def _write_parameter_expression(file_obj, obj, use_symengine, *, version):
234236
# serialize key
235237
if symbol_key == type_keys.Value.PARAMETER_VECTOR:
236238
symbol_data = common.data_to_binary(symbol, _write_parameter_vec)
239+
elif symbol_key == type_keys.Value.PARAMETER_EXPRESSION:
240+
symbol_data = common.data_to_binary(
241+
symbol,
242+
_write_parameter_expression,
243+
use_symengine=use_symengine,
244+
version=version,
245+
)
237246
else:
238247
symbol_data = common.data_to_binary(symbol, _write_parameter)
239248
# serialize value
249+
240250
value_key, value_data = dumps_value(
241251
symbol, version=version, use_symengine=use_symengine
242252
)
@@ -530,10 +540,13 @@ def _read_parameter_expression_v13(file_obj, vectors, version):
530540
symbol = _read_parameter(file_obj)
531541
elif symbol_key == type_keys.Value.PARAMETER_VECTOR:
532542
symbol = _read_parameter_vec(file_obj, vectors)
543+
elif symbol_key == type_keys.Value.PARAMETER_EXPRESSION:
544+
symbol = _read_parameter_expression_v13(file_obj, vectors, version)
533545
else:
534546
raise exceptions.QpyError(f"Invalid parameter expression map type: {symbol_key}")
535547

536548
elem_key = type_keys.Value(elem_data.type)
549+
537550
binary_data = file_obj.read(elem_data.size)
538551
if elem_key == type_keys.Value.INTEGER:
539552
value = struct.unpack("!q", binary_data)
@@ -548,6 +561,7 @@ def _read_parameter_expression_v13(file_obj, vectors, version):
548561
binary_data,
549562
_read_parameter_expression_v13,
550563
vectors=vectors,
564+
version=version,
551565
)
552566
else:
553567
raise exceptions.QpyError(f"Invalid parameter expression map type: {elem_key}")
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# This code is part of Qiskit.
2+
#
3+
# (C) Copyright IBM 2025.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
13+
"""Test serializing ParameterExpressions from qpy."""
14+
15+
import io
16+
from test import QiskitTestCase # pylint: disable=wrong-import-order
17+
from qiskit.circuit import Parameter, QuantumCircuit
18+
from qiskit import qpy
19+
20+
21+
class TestQpySerializeParameterExpression(QiskitTestCase):
22+
"""QPY serializing ParameterExpression"""
23+
24+
def test_roundtrip_equal(self):
25+
"""Test serialize deserialize with ParameterExpression in _qpy_replay"""
26+
a = Parameter("a")
27+
b = Parameter("b")
28+
a1 = a * 2
29+
a2 = a1.subs({a: 3 * b})
30+
31+
qc = QuantumCircuit(1)
32+
qc.rz(a2, 0)
33+
34+
use_symengine = True
35+
version = 13
36+
with io.BytesIO() as container:
37+
qpy.dump(qc, container, version=version, use_symengine=use_symengine)
38+
qc_qpy_str = container.getvalue()
39+
40+
with io.BytesIO(qc_qpy_str) as container:
41+
qc_from_qpy = qpy.load(container)[0]
42+
43+
self.assertEqual(qc, qc_from_qpy)

test/qpy_compat/test_qpy.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,18 @@ def generate_v12_expr():
833833
return [index, shift]
834834

835835

836+
def generate_replay_with_expression_substitutions():
837+
"""Circuits with parameters that have substituted expressions in the replay"""
838+
a = Parameter("a")
839+
b = Parameter("b")
840+
a1 = a * 2
841+
a2 = a1.subs({a: 3 * b})
842+
qc = QuantumCircuit(1)
843+
qc.rz(a2, 0)
844+
845+
return [qc]
846+
847+
836848
def generate_v14_expr():
837849
"""Circuits that contain expressions and types new in QPY v14."""
838850
from qiskit.circuit.classical import expr, types
@@ -909,6 +921,11 @@ def generate_circuits(version_parts, current_version, load_context=False):
909921
if version_parts >= (1, 1, 0):
910922
output_circuits["standalone_vars.qpy"] = generate_standalone_var()
911923
output_circuits["v12_expr.qpy"] = generate_v12_expr()
924+
if version_parts >= (1, 4, 1):
925+
output_circuits["replay_with_expressions.qpy"] = (
926+
generate_replay_with_expression_substitutions()
927+
)
928+
912929
if version_parts >= (2, 0, 0):
913930
output_circuits["v14_expr.qpy"] = generate_v14_expr()
914931
return output_circuits
@@ -1020,6 +1037,8 @@ def load_qpy(qpy_files, version_parts):
10201037
bind = np.linspace(1.0, 2.0, 22)
10211038
elif path == "parameter_vector_expression.qpy":
10221039
bind = np.linspace(1.0, 2.0, 15)
1040+
elif path == "replay_with_expressions.qpy":
1041+
bind = [2.0]
10231042

10241043
assert_equal(
10251044
circuit, qpy_circuits[i], i, version_parts, bind=bind, equivalent=equivalent

0 commit comments

Comments
 (0)