Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions qiskit/transpiler/passes/optimization/collect_1q_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,35 @@

"""Collect sequences of uninterrupted gates acting on 1 qubit."""

from collections.abc import Callable

from qiskit.transpiler.basepasses import AnalysisPass
from qiskit.dagcircuit import DAGCircuit, DAGOpNode


class Collect1qRuns(AnalysisPass):
"""Collect one-qubit subcircuits."""

def __init__(self, filter_fn: Callable[[DAGCircuit, list[DAGOpNode]], bool] | None = None):
"""
Args:
filter_fn: An optional function that filters collected one-qubit runs.
"""
self.filter_fn = filter_fn
super().__init__()

def run(self, dag):
"""Run the Collect1qBlocks pass on `dag`.

The blocks contain "op" nodes in topological order such that all gates
in a block act on the same qubits and are adjacent in the circuit.
in a block act on the same qubits, are adjacent in the circuit, and
satisfy the filtering condition (when specified).

After the execution, ``property_set['run_list']`` is set to a list of
tuples of "op" node.
"""
self.property_set["run_list"] = dag.collect_1q_runs()
run_list = dag.collect_1q_runs()
if self.filter_fn is not None:
run_list = [run for run in run_list if self.filter_fn(dag, run)]
self.property_set["run_list"] = run_list
return dag
19 changes: 17 additions & 2 deletions qiskit/transpiler/passes/optimization/collect_2q_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,38 @@
"""Collect sequences of uninterrupted gates acting on 2 qubits."""

from collections import defaultdict
from collections.abc import Callable

from qiskit.transpiler.basepasses import AnalysisPass
from qiskit.dagcircuit import DAGCircuit, DAGOpNode


class Collect2qBlocks(AnalysisPass):
"""Collect two-qubit subcircuits."""

def __init__(self, filter_fn: Callable[[DAGCircuit, list[DAGOpNode]], bool] | None = None):
"""
Args:
filter_fn: An optional function that filters collected two-qubit blocks.
"""
self.filter_fn = filter_fn
super().__init__()

def run(self, dag):
"""Run the Collect2qBlocks pass on `dag`.

The blocks contain "op" nodes in topological order such that all gates
in a block act on the same qubits and are adjacent in the circuit.
in a block act on the same qubits, are adjacent in the circuit, and
satisfy the filtering condition (when specified).

After the execution, ``property_set['block_list']`` is set to a list of
tuples of "op" node.
"""
self.property_set["commutation_set"] = defaultdict(list)
self.property_set["block_list"] = dag.collect_2q_runs()

block_list = dag.collect_2q_runs()
if self.filter_fn is not None:
block_list = [block for block in block_list if self.filter_fn(dag, block)]
self.property_set["block_list"] = block_list

return dag
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features_transpiler:
- |
Added a new ``filter_fn`` argument to the analysis passes :class:`.Collect1qRuns` and
:class:`.Collect2qBlocks`. This optional callable allows to filter collected circuits
based on custom criteria, for example to return only single-qubits runs with at least
one Hadamard gate, or only 2-qubit blocks with at least 3 CX-gates.
97 changes: 97 additions & 0 deletions test/python/transpiler/test_consolidate_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,3 +767,100 @@ def test_invalid_python_data_does_not_panic(self):
pass_.property_set["block_list"] = [[not_an_op_node]]
with self.assertRaisesRegex(IndexError, "node index.*was not a valid operation"):
pass_.run(dag)


class TestCollect1qRuns(QiskitTestCase):
"""
Additional correctness tests for the Collect1qRuns transpiler pass.
"""

def test_filter(self):
"""Test filter_fn argument."""
qc = QuantumCircuit(2)
qc.h(0)
qc.t(0)
qc.cx(0, 1)
qc.h(1)
qc.s(1)
qc.tdg(1)

with self.subTest("filter_fn is not specified"):
pm = PassManager([Collect1qRuns()])
pm.run(qc)
runs = pm.property_set.get("run_list", [])
self.assertEqual(len(runs), 2)

with self.subTest("filter_fn is None"):
pm = PassManager([Collect1qRuns(filter_fn=None)])
pm.run(qc)
runs = pm.property_set.get("run_list", [])
self.assertEqual(len(runs), 2)

with self.subTest("only runs with at least one S-gate"):

def at_least_one_s_gate(_dag, run):
return any(node.op.name == "s" for node in run)

pm = PassManager([Collect1qRuns(filter_fn=at_least_one_s_gate)])
pm.run(qc)
runs = pm.property_set.get("run_list", [])
self.assertEqual(len(runs), 1)

with self.subTest("only runs without H-gates"):

def no_h_gates(_dag, run):
return all(node.op.name != "h" for node in run)

pm = PassManager([Collect1qRuns(filter_fn=no_h_gates)])
pm.run(qc)
runs = pm.property_set.get("run_list", [])
self.assertEqual(len(runs), 0)


class TestCollect2qBlocks(QiskitTestCase):
"""
Additional correctness tests for the Collect2qBlocks transpiler pass.
"""

def test_filter(self):
"""Test filter_fn argument."""
qc = QuantumCircuit(3)
qc.h(0) # first block
qc.cx(0, 1) # first block
qc.cx(1, 2) # second block
qc.cx(2, 1) # second block
qc.cx(1, 0) # third block
qc.cx(0, 1) # third block
qc.cx(1, 0) # third block

with self.subTest("filter_fn is not specified"):
pm = PassManager([Collect2qBlocks()])
pm.run(qc)
blocks = pm.property_set.get("block_list", [])
self.assertEqual(len(blocks), 3)

with self.subTest("filter_fn is None"):
pm = PassManager([Collect2qBlocks(filter_fn=None)])
pm.run(qc)
blocks = pm.property_set.get("block_list", [])
self.assertEqual(len(blocks), 3)

with self.subTest("only blocks with at least 3 gates"):

def at_least_three_gates(_dag, run):
return len(run) >= 3

pm = PassManager([Collect2qBlocks(filter_fn=at_least_three_gates)])
pm.run(qc)
blocks = pm.property_set.get("block_list", [])
self.assertEqual(len(blocks), 1)

with self.subTest("only blocks with at least 4 gates"):

def at_least_four_gates(_dag, run):
return len(run) >= 4

pm = PassManager([Collect2qBlocks(filter_fn=at_least_four_gates)])
pm.run(qc)
blocks = pm.property_set.get("block_list", [])
self.assertEqual(len(blocks), 0)