Skip to content

Commit cbe07be

Browse files
authored
Add pickle support of TwoQubitControlledUDecomposer (#14038)
* add pickle support * add test for pickle support * update test_assertPickle * add impl IntoPyObject for RXXEquivalent * add a test for pickle with dill * update tests names * update following review * add back another impl
1 parent d41728f commit cbe07be

2 files changed

Lines changed: 62 additions & 0 deletions

File tree

crates/accelerate/src/two_qubit_decompose.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2479,6 +2479,28 @@ impl RXXEquivalent {
24792479
}
24802480
}
24812481
}
2482+
impl<'a, 'py> IntoPyObject<'py> for &'a RXXEquivalent {
2483+
type Target = PyAny;
2484+
type Output = Borrowed<'a, 'py, Self::Target>;
2485+
type Error = PyErr;
2486+
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
2487+
match self {
2488+
RXXEquivalent::Standard(gate) => Ok(gate.get_gate_class(py)?.bind_borrowed(py)),
2489+
RXXEquivalent::CustomPython(gate) => Ok(gate.as_any().bind_borrowed(py)),
2490+
}
2491+
}
2492+
}
2493+
impl<'py> IntoPyObject<'py> for RXXEquivalent {
2494+
type Target = PyAny;
2495+
type Output = Bound<'py, Self::Target>;
2496+
type Error = PyErr;
2497+
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
2498+
match self {
2499+
RXXEquivalent::Standard(gate) => Ok(gate.get_gate_class(py)?.bind(py).clone()),
2500+
RXXEquivalent::CustomPython(gate) => Ok(gate.bind(py).clone().into_any()),
2501+
}
2502+
}
2503+
}
24822504

24832505
#[derive(Clone, Debug)]
24842506
#[pyclass(module = "qiskit._accelerate.two_qubit_decompose", subclass)]
@@ -2920,6 +2942,10 @@ impl TwoQubitControlledUDecomposer {
29202942

29212943
#[pymethods]
29222944
impl TwoQubitControlledUDecomposer {
2945+
fn __getnewargs__(&self) -> (&RXXEquivalent, &str) {
2946+
(&self.rxx_equivalent_gate, self.euler_basis.as_str())
2947+
}
2948+
29232949
/// Initialize the KAK decomposition.
29242950
/// Args:
29252951
/// rxx_equivalent_gate: Gate that is locally equivalent to an :class:`.RXXGate`:

test/python/synthesis/test_synthesis.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import contextlib
1818
import logging
1919
import math
20+
import dill
2021
import numpy as np
2122
import scipy
2223
import scipy.stats
@@ -1544,6 +1545,41 @@ def inverse(self, annotated: bool = False):
15441545
circ = decomposer(unitary)
15451546
self.assertEqual(Operator(unitary), Operator(circ))
15461547

1548+
def test_assert_pickle(self):
1549+
"""Assert that TwoQubitControlledUDecomposer supports pickle"""
1550+
1551+
decomp = TwoQubitControlledUDecomposer(RXXGate)
1552+
1553+
pkl = pickle.dumps(decomp, protocol=max(4, pickle.DEFAULT_PROTOCOL))
1554+
decomp_cpy = pickle.loads(pkl)
1555+
msg_base = f"decomp:\n{decomp}\ndecomp_cpy:\n{decomp_cpy}"
1556+
self.assertEqual(type(decomp), type(decomp_cpy), msg_base)
1557+
self.assertEqual(decomp.rxx_equivalent_gate, decomp_cpy.rxx_equivalent_gate, msg=msg_base)
1558+
self.assertEqual(decomp.euler_basis, decomp_cpy.euler_basis, msg=msg_base)
1559+
1560+
def test_assert_pickle_with_dill(self):
1561+
"""Assert that TwoQubitControlledUDecomposer supports pickle"""
1562+
1563+
class CustomXXGate(RXXGate):
1564+
"""Custom RXXGate subclass that's not a standard gate"""
1565+
1566+
_standard_gate = None
1567+
1568+
def __init__(self, theta, label=None):
1569+
super().__init__(theta, label)
1570+
self.name = "MyCustomXXGate"
1571+
1572+
decomp = TwoQubitControlledUDecomposer(CustomXXGate)
1573+
1574+
pkl = dill.dumps(decomp, protocol=max(4, pickle.DEFAULT_PROTOCOL))
1575+
decomp_cpy = dill.loads(pkl)
1576+
msg_base = f"decomp:\n{decomp}\ndecomp_cpy:\n{decomp_cpy}"
1577+
self.assertEqual(type(decomp), type(decomp_cpy), msg_base)
1578+
self.assertEqual(
1579+
decomp.rxx_equivalent_gate.name, decomp_cpy.rxx_equivalent_gate.name, msg=msg_base
1580+
)
1581+
self.assertEqual(decomp.euler_basis, decomp_cpy.euler_basis, msg=msg_base)
1582+
15471583

15481584
class TestDecomposeProductRaises(QiskitTestCase):
15491585
"""Check that exceptions are raised when 2q matrix is not a product of 1q unitaries"""

0 commit comments

Comments
 (0)