Skip to content

Commit ddd0401

Browse files
authored
Fix deepcopy/pickle of DAGCircuit variable IO nodes (#14041)
* Fix var DAGOutNode bug in DAGCircuit::pack_into. We were accidentally using NodeType::VarIn when extracting DAGOutNode. * Add release note since this is pre 2.0.0.
1 parent da47ded commit ddd0401

3 files changed

Lines changed: 81 additions & 1 deletion

File tree

crates/circuit/src/dag_circuit.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5708,7 +5708,7 @@ impl DAGCircuit {
57085708
NodeType::ClbitOut(self.clbits.find(&clbit).unwrap())
57095709
} else {
57105710
let var = PyObjectAsKey::new(wire);
5711-
NodeType::VarIn(self.vars.find(&var).unwrap())
5711+
NodeType::VarOut(self.vars.find(&var).unwrap())
57125712
}
57135713
} else if let Ok(op_node) = b.downcast::<DAGOpNode>() {
57145714
let op_node = op_node.borrow();
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed a bug in :class:`~.dagcircuit.DAGCircuit` that would cause
5+
output :class:`~.expr.Var` nodes to become input nodes during
6+
``deepcopy`` and pickling.

test/python/dagcircuit/test_dagcircuit.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1851,6 +1851,80 @@ def test_present_vars(self):
18511851
right.add_captured_var(a_u8_other)
18521852
self.assertNotEqual(left, right)
18531853

1854+
def test_pickle_vars(self):
1855+
"""Test vars preserved through pickle."""
1856+
a = expr.Var.new("a", types.Bool())
1857+
b = expr.Var.new("b", types.Uint(8))
1858+
1859+
# Check inputs.
1860+
dag = DAGCircuit()
1861+
dag.add_input_var(a)
1862+
1863+
self.assertEqual(dag.num_vars, 1)
1864+
self.assertEqual(dag.num_input_vars, 1)
1865+
1866+
with io.BytesIO() as buf:
1867+
pickle.dump(dag, buf)
1868+
buf.seek(0)
1869+
output = pickle.load(buf)
1870+
1871+
self.assertEqual(output.num_vars, 1)
1872+
self.assertEqual(output.num_input_vars, 1)
1873+
self.assertEqual(output, dag)
1874+
1875+
# Check captures and declarations.
1876+
dag = DAGCircuit()
1877+
dag.add_declared_var(a)
1878+
dag.add_captured_var(b)
1879+
1880+
self.assertEqual(dag.num_vars, 2)
1881+
self.assertEqual(dag.num_captured_vars, 1)
1882+
self.assertEqual(dag.num_declared_vars, 1)
1883+
1884+
with io.BytesIO() as buf:
1885+
pickle.dump(dag, buf)
1886+
buf.seek(0)
1887+
output = pickle.load(buf)
1888+
1889+
self.assertEqual(output.num_vars, 2)
1890+
self.assertEqual(output.num_captured_vars, 1)
1891+
self.assertEqual(output.num_declared_vars, 1)
1892+
self.assertEqual(output, dag)
1893+
1894+
def test_deepcopy_vars(self):
1895+
"""Test vars preserved through deepcopy."""
1896+
a = expr.Var.new("a", types.Bool())
1897+
b = expr.Var.new("b", types.Uint(8))
1898+
1899+
# Check inputs.
1900+
dag = DAGCircuit()
1901+
dag.add_input_var(a)
1902+
1903+
self.assertEqual(dag.num_vars, 1)
1904+
self.assertEqual(dag.num_input_vars, 1)
1905+
1906+
output = copy.deepcopy(dag)
1907+
1908+
self.assertEqual(output.num_vars, 1)
1909+
self.assertEqual(output.num_input_vars, 1)
1910+
self.assertEqual(output, dag)
1911+
1912+
# Check captures and declarations.
1913+
dag = DAGCircuit()
1914+
dag.add_declared_var(a)
1915+
dag.add_captured_var(b)
1916+
1917+
self.assertEqual(dag.num_vars, 2)
1918+
self.assertEqual(dag.num_captured_vars, 1)
1919+
self.assertEqual(dag.num_declared_vars, 1)
1920+
1921+
output = copy.deepcopy(dag)
1922+
1923+
self.assertEqual(output.num_vars, 2)
1924+
self.assertEqual(output.num_captured_vars, 1)
1925+
self.assertEqual(output.num_declared_vars, 1)
1926+
self.assertEqual(output, dag)
1927+
18541928
def test_wires_added_for_simple_classical_vars(self):
18551929
"""Var uses should be represented in the wire structure."""
18561930
a = expr.Var.new("a", types.Bool())

0 commit comments

Comments
 (0)