Skip to content

Commit 0551a40

Browse files
phalakbhraynelfss
andauthored
Port ConstrainedReschedule to Rust (#14883)
* Initial commit, port constrained reschedule pass to Rust * Complete remaining portions of the logic implementation * Fix lint * Fix lint * Implement suggested changes * Fix lint * Fix lint * Fix: Use new `NodeDurations` struct in `ConstrainedReschedule` * Fix: Redundant casting. * Docs: Update header string to use https * Fix: Address review comments - Added check for `Delay` and directive operations in `_push_node`. - Use `saturating_sub` when calculating the shift variable in the same method. --------- Co-authored-by: Raynel Sanchez <raynelfss@hotmail.com>
1 parent f0d87ec commit 0551a40

6 files changed

Lines changed: 363 additions & 127 deletions

File tree

crates/pyext/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
4545
add_submodule(m, ::qiskit_transpiler::commutation_checker::commutation_checker, "commutation_checker")?;
4646
add_submodule(m, ::qiskit_transpiler::passes::commutative_optimization_mod, "commutative_optimization")?;
4747
add_submodule(m, ::qiskit_transpiler::passes::consolidate_blocks_mod, "consolidate_blocks")?;
48+
add_submodule(m, ::qiskit_transpiler::passes::constrained_reschedule_mod, "constrained_reschedule")?;
4849
add_submodule(m, ::qiskit_synthesis::linalg::cos_sin_decomp::cos_sin_decomp, "cos_sin_decomp")?;
4950
add_submodule(m, ::qiskit_transpiler::passes::dense_layout_mod, "dense_layout")?;
5051
add_submodule(m, ::qiskit_transpiler::equivalence::equivalence, "equivalence")?;
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
// This code is part of Qiskit.
2+
//
3+
// (C) Copyright IBM 2025.
4+
//
5+
// This code is licensed under the Apache License, Version 2.0. You may
6+
// obtain a copy of this license in the LICENSE.txt file in the root directory
7+
// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0.
8+
//
9+
// Any modifications or derivative works of this code must retain this
10+
// copyright notice, and modified files need to carry a notice indicating
11+
// that they have been altered from the originals.
12+
13+
use crate::TranspilerError;
14+
use crate::passes::schedule_analysis::{NodeDurations, PyNodeDurations};
15+
use crate::target::Target;
16+
use ::hashbrown::HashSet;
17+
use ahash::RandomState;
18+
use indexmap::IndexMap;
19+
use pyo3::exceptions::PyValueError;
20+
use pyo3::prelude::*;
21+
use pyo3::{Bound, PyResult, pyfunction, wrap_pyfunction};
22+
use qiskit_circuit::PhysicalQubit;
23+
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType};
24+
use qiskit_circuit::operations::Param;
25+
use qiskit_circuit::operations::{Operation, OperationRef, StandardInstruction};
26+
use rustworkx_core::petgraph::stable_graph::NodeIndex;
27+
28+
/// Returns the immediate successor operation nodes of a given node in the DAG.
29+
///
30+
/// This function traverses the DAG to find all nodes that are direct successors
31+
/// of the given node and filters them to return only operation nodes.
32+
///
33+
/// # Arguments
34+
///
35+
/// * `dag` - Reference to the DAGCircuit containing the quantum circuit
36+
/// * `node_index` - Index of the node whose successors we want to find
37+
///
38+
/// # Returns
39+
///
40+
/// An iterator of `NodeIndex` values representing the immediate successor operation nodes.
41+
fn get_next_gate(dag: &DAGCircuit, node_index: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
42+
dag.quantum_successors(node_index)
43+
.filter(|&idx| matches!(dag[idx], NodeType::Operation(_)))
44+
}
45+
46+
/// Update the start time of the current node to satisfy alignment constraints.
47+
/// Immediate successors are pushed back to avoid overlap and will be processed later.
48+
///
49+
/// Note:
50+
/// This logic assumes that all bits in the qregs and cregs synchronously start and end,
51+
/// i.e. occupy the same time slot, but qregs and cregs can take different time slots
52+
/// due to classical I/O latencies.
53+
///
54+
/// # Args:
55+
/// * `py` - Python interpreter reference for PyO3 operations
56+
/// * `dag` - Reference to the DAGCircuit to be rescheduled with constraints
57+
/// * `node_index` - Index of the current node to be processed
58+
/// * `node_start_time` - Mutable Python dictionary mapping node indices to start times
59+
/// * `clbit_write_latency` - Additional latency for classical bit write operations
60+
/// * `pulse_align` - Alignment constraint for gate operations (in dt units)
61+
/// * `acquire_align` - Alignment constraint for measurement/reset operations (in dt units)
62+
/// * `target` - Optional target backend for duration information
63+
fn push_node_back(
64+
dag: &DAGCircuit,
65+
node_index: NodeIndex,
66+
node_start_time: &mut IndexMap<NodeIndex<u32>, u64, RandomState>,
67+
clbit_write_latency: u32,
68+
pulse_align: u32,
69+
acquire_align: u32,
70+
target: Option<&Target>,
71+
) -> PyResult<()> {
72+
let NodeType::Operation(op) = &dag[node_index] else {
73+
unreachable!("topological_op_nodes() should only return operations.")
74+
};
75+
76+
let op_view = op.op.view();
77+
let alignment = match op_view {
78+
OperationRef::Gate(_) | OperationRef::StandardGate(_) => Some(pulse_align),
79+
OperationRef::StandardInstruction(StandardInstruction::Reset)
80+
| OperationRef::StandardInstruction(StandardInstruction::Measure) => Some(acquire_align),
81+
OperationRef::StandardInstruction(StandardInstruction::Delay(_)) => None,
82+
_ => {
83+
if !op_view.directive() {
84+
None
85+
} else {
86+
return Err(TranspilerError::new_err(format!(
87+
"Unknown operation type for '{}'.",
88+
op_view.name()
89+
)));
90+
}
91+
}
92+
};
93+
94+
let mut this_t0: u64 = *node_start_time
95+
.get(&node_index)
96+
.ok_or_else(|| PyValueError::new_err("Missing value in node_start_time"))?;
97+
98+
if let Some(alignment) = alignment {
99+
let misalignment = this_t0 % alignment as u64;
100+
let shift = if misalignment != 0 {
101+
(alignment as u64).saturating_sub(misalignment)
102+
} else {
103+
0
104+
};
105+
this_t0 += shift;
106+
node_start_time
107+
.entry(node_index)
108+
.and_modify(|old_t0| *old_t0 = this_t0)
109+
.or_insert(this_t0);
110+
}
111+
112+
let new_t1q = if let Some(target) = target {
113+
let qargs: Vec<PhysicalQubit> = dag
114+
.qargs_interner()
115+
.get(op.qubits)
116+
.iter()
117+
.map(|q| PhysicalQubit(q.index() as u32))
118+
.collect();
119+
let duration = target.get_duration(op.op.name(), &qargs).unwrap_or(0.0);
120+
this_t0 + duration as u64
121+
} else if matches!(
122+
op_view,
123+
OperationRef::StandardInstruction(StandardInstruction::Delay(_))
124+
) {
125+
let params = op.params_view();
126+
let param = params
127+
.first()
128+
.ok_or_else(|| PyValueError::new_err("Delay instruction missing duration parameter"))?;
129+
let duration = match param {
130+
Param::Obj(val) => {
131+
// Try to extract as different numeric types
132+
Python::attach(|py| val.bind(py).extract::<u64>())
133+
}
134+
Param::Float(f) => Ok(*f as u64),
135+
_ => Err(TranspilerError::new_err(
136+
"The provided Delay duration is not in terms of dt.",
137+
)),
138+
}?;
139+
140+
this_t0 + duration
141+
} else {
142+
this_t0
143+
};
144+
145+
let this_qubits: HashSet<_> = dag
146+
.qargs_interner()
147+
.get(op.qubits)
148+
.iter()
149+
.map(|q| q.index())
150+
.collect();
151+
152+
// Handle classical bits based on operation type
153+
let (new_t1c, this_clbits) = if matches!(
154+
op_view,
155+
OperationRef::StandardInstruction(StandardInstruction::Measure)
156+
| OperationRef::StandardInstruction(StandardInstruction::Reset)
157+
) {
158+
// creg access ends at the end of instruction
159+
let new_t1c = Some(new_t1q);
160+
let this_clbits: HashSet<_> = dag
161+
.cargs_interner()
162+
.get(op.clbits)
163+
.iter()
164+
.map(|c| c.index())
165+
.collect();
166+
(new_t1c, this_clbits)
167+
} else {
168+
(None, HashSet::new())
169+
};
170+
// Check immediate successors for overlap
171+
for next_node_index in get_next_gate(dag, node_index) {
172+
// Get the next node
173+
let NodeType::Operation(next_node) = &dag[next_node_index] else {
174+
unreachable!("topological_op_nodes() should only return operations.")
175+
};
176+
177+
// Compute next node start time separately for qreg and creg
178+
let next_t0q: u64 = node_start_time
179+
.get(&next_node_index)
180+
.copied()
181+
.expect("Expected value in node_start_time for next_node_index");
182+
183+
let next_qubits: HashSet<_> = dag
184+
.qargs_interner()
185+
.get(next_node.qubits)
186+
.iter()
187+
.map(|q| q.index())
188+
.collect();
189+
190+
let next_op_view = next_node.op.view();
191+
let (next_t0c, next_clbits) = if matches!(
192+
next_op_view,
193+
OperationRef::StandardInstruction(StandardInstruction::Measure)
194+
| OperationRef::StandardInstruction(StandardInstruction::Reset)
195+
) {
196+
// creg access starts after write latency
197+
let next_t0c = Some(next_t0q + clbit_write_latency as u64);
198+
let next_clbits: HashSet<_> = dag
199+
.cargs_interner()
200+
.get(next_node.clbits)
201+
.iter()
202+
.map(|c| c.index())
203+
.collect();
204+
(next_t0c, next_clbits)
205+
} else {
206+
(None, HashSet::new())
207+
};
208+
209+
// Compute overlap if there is qubits overlap
210+
let qreg_overlap = if !this_qubits.is_disjoint(&next_qubits) {
211+
new_t1q - next_t0q
212+
} else {
213+
0
214+
};
215+
216+
// Compute overlap if there is clbits overlap
217+
let creg_overlap = if !this_clbits.is_empty()
218+
&& !next_clbits.is_empty()
219+
&& !this_clbits.is_disjoint(&next_clbits)
220+
{
221+
if let (Some(t1c), Some(t0c)) = (new_t1c, next_t0c) {
222+
t1c - t0c
223+
} else {
224+
0
225+
}
226+
} else {
227+
0
228+
};
229+
230+
// Shift next node if there is finite overlap in either qubits or clbits
231+
let overlap = qreg_overlap.max(creg_overlap);
232+
if overlap > 0 {
233+
let new_start_time = next_t0q + overlap;
234+
node_start_time
235+
.entry(next_node_index)
236+
.and_modify(|old_start| *old_start = new_start_time)
237+
.or_insert(new_start_time);
238+
}
239+
}
240+
Ok(())
241+
}
242+
243+
#[pyfunction]
244+
#[pyo3(name="constrained_reschedule", signature=(dag, node_start_time, clbit_write_latency, acquire_align, pulse_align, target))]
245+
pub fn py_run_constrained_reschedule(
246+
dag: &DAGCircuit,
247+
mut node_start_time: PyNodeDurations,
248+
clbit_write_latency: u32,
249+
acquire_align: u32,
250+
pulse_align: u32,
251+
target: Option<&Target>,
252+
) -> PyResult<PyNodeDurations> {
253+
let NodeDurations::Dt(durations) = &mut *node_start_time else {
254+
return Err(TranspilerError::new_err(
255+
"The durations provided have not been properly converted to 'dt' units.",
256+
));
257+
};
258+
run_constrained_reschedule(
259+
dag,
260+
durations,
261+
clbit_write_latency,
262+
acquire_align,
263+
pulse_align,
264+
target,
265+
)?;
266+
Ok(node_start_time)
267+
}
268+
269+
pub fn run_constrained_reschedule(
270+
dag: &DAGCircuit,
271+
node_start_time: &mut IndexMap<NodeIndex<u32>, u64, RandomState>,
272+
clbit_write_latency: u32,
273+
acquire_align: u32,
274+
pulse_align: u32,
275+
target: Option<&Target>,
276+
) -> PyResult<()> {
277+
for node_index in dag.topological_op_nodes(false) {
278+
let start_time = node_start_time.get(&node_index);
279+
let val = *start_time.ok_or_else(|| {
280+
TranspilerError::new_err(format!(
281+
"Missing start time for node {}. Run scheduler again.",
282+
node_index.index()
283+
))
284+
})?;
285+
286+
if val == 0 {
287+
continue;
288+
}
289+
290+
push_node_back(
291+
dag,
292+
node_index,
293+
node_start_time,
294+
clbit_write_latency,
295+
acquire_align,
296+
pulse_align,
297+
target,
298+
)?;
299+
}
300+
301+
Ok(())
302+
}
303+
304+
pub fn constrained_reschedule_mod(m: &Bound<PyModule>) -> PyResult<()> {
305+
m.add_wrapped(wrap_pyfunction!(py_run_constrained_reschedule))?;
306+
Ok(())
307+
}

crates/transpiler/src/passes/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ mod commutation_analysis;
2929
mod commutation_cancellation;
3030
mod commutative_optimization;
3131
mod consolidate_blocks;
32+
mod constrained_reschedule;
3233
mod convert_to_pauli_rotations;
3334
mod dense_layout;
3435
mod disjoint_layout;
@@ -64,6 +65,7 @@ pub use commutation_analysis::{analyze_commutations, commutation_analysis_mod};
6465
pub use commutation_cancellation::{cancel_commutations, commutation_cancellation_mod};
6566
pub use commutative_optimization::{commutative_optimization_mod, run_commutative_optimization};
6667
pub use consolidate_blocks::{DecomposerType, consolidate_blocks_mod, run_consolidate_blocks};
68+
pub use constrained_reschedule::{constrained_reschedule_mod, run_constrained_reschedule};
6769
pub use convert_to_pauli_rotations::{
6870
convert_to_pauli_rotations_mod, py_convert_to_pauli_rotations,
6971
};

crates/transpiler/src/passes/schedule_analysis/mod.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
pub mod alap_schedule_analysis;
1414
pub mod asap_schedule_analysis;
1515

16-
use std::ops::{Add, Deref, Sub};
16+
use std::ops::{
17+
Add, AddAssign, Deref, DerefMut, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign,
18+
};
1719

1820
use ahash::RandomState;
1921
use hashbrown::HashMap;
@@ -32,7 +34,38 @@ use rustworkx_core::petgraph::graph::NodeIndex;
3234

3335
use crate::TranspilerError;
3436

35-
pub trait TimeOps: Copy + PartialOrd + Add<Output = Self> + Sub<Output = Self> {
37+
pub trait Number:
38+
Copy
39+
+ Add<Output = Self>
40+
+ Sub<Output = Self>
41+
+ Mul<Output = Self>
42+
+ Div<Output = Self>
43+
+ Rem<Output = Self>
44+
+ AddAssign
45+
+ SubAssign
46+
+ MulAssign
47+
+ DivAssign
48+
+ RemAssign
49+
{
50+
}
51+
52+
impl<
53+
T: Copy
54+
+ Add<Output = Self>
55+
+ Sub<Output = Self>
56+
+ Mul<Output = Self>
57+
+ Div<Output = Self>
58+
+ Rem<Output = Self>
59+
+ AddAssign
60+
+ SubAssign
61+
+ MulAssign
62+
+ DivAssign
63+
+ RemAssign,
64+
> Number for T
65+
{
66+
}
67+
68+
pub trait TimeOps: Copy + PartialOrd + Number {
3669
fn zero() -> Self;
3770
fn max<'a>(a: &'a Self, b: &'a Self) -> &'a Self;
3871
}
@@ -77,6 +110,12 @@ impl Deref for PyNodeDurations {
77110
}
78111
}
79112

113+
impl DerefMut for PyNodeDurations {
114+
fn deref_mut(&mut self) -> &mut Self::Target {
115+
&mut self.inner
116+
}
117+
}
118+
80119
#[pymethods]
81120
impl PyNodeDurations {
82121
#[new]

0 commit comments

Comments
 (0)