Skip to content

Commit 301aac0

Browse files
Use struct for NodeDurations in both ALAP and ASAP Scheduling Passes (#15276)
* Initial: Reorganize schedule analysis passes - Add initial ``NodeDurations`` struct. * Fix: Use NodeDurations struct on passes This commit is incomplete and most features will stop working. * Add: Use `NodeDurations` in python. - Add magic methods to `PyNodeDurations`. * Fix: Use indices to get nodes from Dynamical Decoupling passes - Add `items` method. * Lint: Formatting * Fix: Keep original dag nodes and give path to getting them from scratch. - Add an extra field to `PyNodeDurations` to keep mapping between indices and original dag nodes. - Add methods to update `NodeDurations` using mappings with a subset of the current nodes or new mappings as long as the user has the original dag. - Reverted some changes inside of Context Aware DD. * Format: Elide lifetimes. * Apply suggestions from code review Co-authored-by: Alexander Ivrii <alexi@il.ibm.com> * Fix: Add missing magic methods. * Fix: Pre-computing error in `__getitem__` - An oversight allowed an error variant to precompute itself even if the result was `Ok` via `ok_or`. Replaced it with `ok_or_else`. - Fix formatting issue. * Add: `__iter__`, `keys` and `values` to `NodeDurations`. --------- Co-authored-by: Alexander Ivrii <alexi@il.ibm.com>
1 parent e740447 commit 301aac0

10 files changed

Lines changed: 463 additions & 146 deletions

File tree

crates/pyext/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
7272
add_submodule(m, ::qiskit_accelerate::sampled_exp_val::sampled_exp_val, "sampled_exp_val")?;
7373
add_submodule(m, ::qiskit_quantum_info::sparse_observable::sparse_observable, "sparse_observable")?;
7474
add_submodule(m, ::qiskit_quantum_info::sparse_pauli_op::sparse_pauli_op, "sparse_pauli_op")?;
75+
add_submodule(m, ::qiskit_transpiler::passes::scheduling_mod, "scheduling")?;
7576
add_submodule(m, ::qiskit_quantum_info::unitary_sim::unitary_sim, "unitary_sim")?;
7677
add_submodule(m, ::qiskit_transpiler::passes::split_2q_unitaries_mod, "split_2q_unitaries")?;
7778
add_submodule(m, ::qiskit_synthesis::synthesis, "synthesis")?;

crates/transpiler/src/passes/mod.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
//! crate. These are public to be passed to qiskit-pyext and are only used
2222
//! for building Python submodules.
2323
24-
mod alap_schedule_analysis;
2524
mod apply_layout;
26-
mod asap_schedule_analysis;
2725
mod barrier_before_final_measurement;
2826
mod basis_translator;
2927
mod check_map;
@@ -47,6 +45,7 @@ mod optimize_clifford_t;
4745
mod remove_diagonal_gates_before_measure;
4846
mod remove_identity_equiv;
4947
pub mod sabre;
48+
mod schedule_analysis;
5049
mod split_2q_unitaries;
5150
mod substitute_pi4_rotations;
5251
mod synthesize_rz_rotations;
@@ -55,9 +54,7 @@ mod unroll_3q_or_more;
5554
pub mod vf2;
5655
mod wrap_angles;
5756

58-
pub use alap_schedule_analysis::{alap_schedule_analysis_mod, run_alap_schedule_analysis};
5957
pub use apply_layout::{apply_layout, apply_layout_mod, update_layout};
60-
pub use asap_schedule_analysis::{asap_schedule_analysis_mod, run_asap_schedule_analysis};
6158
pub use barrier_before_final_measurement::{
6259
barrier_before_final_measurements_mod, run_barrier_before_final_measurements,
6360
};
@@ -95,6 +92,13 @@ pub use remove_diagonal_gates_before_measure::{
9592
remove_diagonal_gates_before_measure_mod, run_remove_diagonal_before_measure,
9693
};
9794
pub use remove_identity_equiv::{remove_identity_equiv_mod, run_remove_identity_equiv};
95+
pub use schedule_analysis::alap_schedule_analysis::{
96+
alap_schedule_analysis_mod, run_alap_schedule_analysis,
97+
};
98+
pub use schedule_analysis::asap_schedule_analysis::{
99+
asap_schedule_analysis_mod, run_asap_schedule_analysis,
100+
};
101+
pub use schedule_analysis::scheduling_mod;
98102
pub use split_2q_unitaries::{run_split_2q_unitaries, split_2q_unitaries_mod};
99103
pub use substitute_pi4_rotations::{py_run_substitute_pi4_rotations, substitute_pi4_rotations_mod};
100104
pub use synthesize_rz_rotations::{py_run_synthesize_rz_rotations, synthesize_rz_rotations_mod};

crates/transpiler/src/passes/alap_schedule_analysis.rs renamed to crates/transpiler/src/passes/schedule_analysis/alap_schedule_analysis.rs

Lines changed: 21 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -10,52 +10,30 @@
1010
// copyright notice, and modified files need to carry a notice indicating
1111
// that they have been altered from the originals.
1212

13+
use super::TimeOps;
1314
use crate::TranspilerError;
15+
use crate::passes::schedule_analysis::{NodeDurations, PyNodeDurations};
16+
use ahash::RandomState;
1417
use hashbrown::HashMap;
18+
use indexmap::IndexMap;
1519
use pyo3::prelude::*;
16-
use pyo3::types::PyDict;
1720
use qiskit_circuit::dag_circuit::{DAGCircuit, Wire};
18-
use qiskit_circuit::dag_node::{DAGNode, DAGOpNode};
1921
use qiskit_circuit::operations::{OperationRef, StandardInstruction};
2022
use qiskit_circuit::{Clbit, Qubit};
2123
use rustworkx_core::petgraph::prelude::NodeIndex;
22-
use std::ops::{Add, Sub};
23-
24-
pub trait TimeOps: Copy + PartialOrd + Add<Output = Self> + Sub<Output = Self> {
25-
fn zero() -> Self;
26-
fn max<'a>(a: &'a Self, b: &'a Self) -> &'a Self;
27-
}
28-
29-
impl TimeOps for u64 {
30-
fn zero() -> Self {
31-
0
32-
}
33-
fn max<'a>(a: &'a Self, b: &'a Self) -> &'a Self {
34-
if a >= b { a } else { b }
35-
}
36-
}
37-
38-
impl TimeOps for f64 {
39-
fn zero() -> Self {
40-
0.0
41-
}
42-
fn max<'a>(a: &'a Self, b: &'a Self) -> &'a Self {
43-
if a >= b { a } else { b }
44-
}
45-
}
4624

4725
pub fn run_alap_schedule_analysis<T: TimeOps>(
4826
dag: &DAGCircuit,
4927
clbit_write_latency: T,
50-
node_durations: HashMap<NodeIndex, T>,
51-
) -> PyResult<HashMap<NodeIndex, T>> {
28+
node_durations: &IndexMap<NodeIndex, T, RandomState>,
29+
) -> PyResult<IndexMap<NodeIndex, T, RandomState>> {
5230
if dag.qregs().len() != 1 || !dag.qregs_data().contains_key("q") {
5331
return Err(TranspilerError::new_err(
5432
"ALAP schedule runs on physical circuits only",
5533
));
5634
}
5735

58-
let mut node_start_time: HashMap<NodeIndex, T> = HashMap::new();
36+
let mut node_start_time: IndexMap<NodeIndex, T, RandomState> = IndexMap::default();
5937
let mut idle_before: HashMap<Wire, T> = HashMap::new();
6038

6139
let zero = T::zero();
@@ -154,7 +132,7 @@ pub fn run_alap_schedule_analysis<T: TimeOps>(
154132

155133
// Note that ALAP pass is inversely scheduled, thus
156134
// t0 is computed by subtracting t1 from the entire circuit duration.
157-
let mut result: HashMap<NodeIndex, T> = HashMap::new();
135+
let mut result: IndexMap<NodeIndex, T, RandomState> = IndexMap::default();
158136
for (node_idx, t1) in node_start_time {
159137
let final_time = *circuit_duration - t1;
160138
result.insert(node_idx, final_time);
@@ -176,61 +154,25 @@ pub fn run_alap_schedule_analysis<T: TimeOps>(
176154
///
177155
#[pyo3(name = "alap_schedule_analysis", signature= (dag, clbit_write_latency, node_durations))]
178156
pub fn py_run_alap_schedule_analysis(
179-
py: Python,
180157
dag: &DAGCircuit,
181158
clbit_write_latency: u64,
182-
node_durations: &Bound<PyDict>,
183-
) -> PyResult<Py<PyDict>> {
159+
mut node_durations: PyNodeDurations,
160+
) -> PyResult<PyNodeDurations> {
184161
// Extract indices and durations from PyDict
185162
// Get the first duration type
186-
let mut iter = node_durations.iter();
187-
let py_dict = PyDict::new(py);
188-
let Some((_, first_duration)) = iter.next() else {
189-
// Empty circuit.
190-
return Ok(py_dict.into());
191-
};
192-
if first_duration.extract::<u64>().is_ok() {
193-
// All durations are of type u64
194-
let mut op_durations = HashMap::new();
195-
for (py_node, py_duration) in node_durations.iter() {
196-
let node_idx = py_node
197-
.cast_into::<DAGOpNode>()?
198-
.cast_into::<DAGNode>()?
199-
.borrow()
200-
.node
201-
.expect("Node index not found.");
202-
let val = py_duration.extract::<u64>()?;
203-
op_durations.insert(node_idx, val);
204-
}
205-
let node_start_time =
206-
run_alap_schedule_analysis::<u64>(dag, clbit_write_latency, op_durations)?;
207-
for (node_idx, t1) in node_start_time {
208-
let node = dag.get_node(py, node_idx)?;
209-
py_dict.set_item(node, t1)?;
210-
}
211-
} else if first_duration.extract::<f64>().is_ok() {
212-
// All durations are of type f64
213-
let mut op_durations = HashMap::new();
214-
for (py_node, py_duration) in node_durations.iter() {
215-
let node_idx = py_node
216-
.cast_into::<DAGOpNode>()?
217-
.cast_into::<DAGNode>()?
218-
.borrow()
219-
.node
220-
.expect("Node index not found.");
221-
let val = py_duration.extract::<f64>()?;
222-
op_durations.insert(node_idx, val);
163+
// Extract indices and durations from PyDict
164+
// Get the first duration type
165+
let new_durations: NodeDurations = match &*node_durations {
166+
NodeDurations::Dt(node_durations) => {
167+
run_alap_schedule_analysis(dag, clbit_write_latency, node_durations)?.into()
223168
}
224-
let node_start_time =
225-
run_alap_schedule_analysis::<f64>(dag, clbit_write_latency as f64, op_durations)?;
226-
for (node_idx, t1) in node_start_time {
227-
let node = dag.get_node(py, node_idx)?;
228-
py_dict.set_item(node, t1)?;
169+
NodeDurations::Seconds(node_durations) => {
170+
run_alap_schedule_analysis::<f64>(dag, clbit_write_latency as f64, node_durations)?
171+
.into()
229172
}
230-
} else {
231-
return Err(TranspilerError::new_err("Duration must be int or float"));
232-
}
233-
Ok(py_dict.into())
173+
};
174+
node_durations.update_durations(new_durations)?;
175+
Ok(node_durations)
234176
}
235177

236178
pub fn alap_schedule_analysis_mod(m: &Bound<PyModule>) -> PyResult<()> {

crates/transpiler/src/passes/asap_schedule_analysis.rs renamed to crates/transpiler/src/passes/schedule_analysis/asap_schedule_analysis.rs

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,28 @@
1111
// that they have been altered from the originals.
1212

1313
use crate::TranspilerError;
14-
use crate::passes::alap_schedule_analysis::TimeOps;
14+
use crate::passes::schedule_analysis::{NodeDurations, PyNodeDurations, TimeOps};
15+
use ahash::RandomState;
1516
use hashbrown::HashMap;
17+
use indexmap::IndexMap;
1618
use pyo3::prelude::*;
17-
use pyo3::types::PyDict;
1819
use qiskit_circuit::dag_circuit::{DAGCircuit, Wire};
19-
use qiskit_circuit::dag_node::{DAGNode, DAGOpNode};
2020
use qiskit_circuit::operations::{OperationRef, StandardInstruction};
2121
use qiskit_circuit::{Clbit, Qubit};
2222
use rustworkx_core::petgraph::prelude::NodeIndex;
2323

2424
pub fn run_asap_schedule_analysis<T: TimeOps>(
2525
dag: &DAGCircuit,
2626
clbit_write_latency: T,
27-
node_durations: HashMap<NodeIndex, T>,
28-
) -> PyResult<HashMap<NodeIndex, T>> {
27+
node_durations: &IndexMap<NodeIndex, T, RandomState>,
28+
) -> PyResult<IndexMap<NodeIndex, T, RandomState>> {
2929
if dag.qregs().len() != 1 || !dag.qregs_data().contains_key("q") {
3030
return Err(TranspilerError::new_err(
3131
"ASAP schedule runs on physical circuits only",
3232
));
3333
}
3434

35-
let mut node_start_time: HashMap<NodeIndex, T> = HashMap::new();
35+
let mut node_start_time: IndexMap<NodeIndex, T, RandomState> = IndexMap::default();
3636
let mut idle_after: HashMap<Wire, T> = HashMap::new();
3737

3838
let zero = T::zero();
@@ -158,61 +158,23 @@ pub fn run_asap_schedule_analysis<T: TimeOps>(
158158
///
159159
#[pyo3(name = "asap_schedule_analysis", signature= (dag, clbit_write_latency, node_durations))]
160160
pub fn py_run_asap_schedule_analysis(
161-
py: Python,
162161
dag: &DAGCircuit,
163162
clbit_write_latency: u64,
164-
node_durations: &Bound<PyDict>,
165-
) -> PyResult<Py<PyDict>> {
163+
mut node_durations: PyNodeDurations,
164+
) -> PyResult<PyNodeDurations> {
166165
// Extract indices and durations from PyDict
167166
// Get the first duration type
168-
let mut iter = node_durations.iter();
169-
let py_dict = PyDict::new(py);
170-
let Some((_, first_duration)) = iter.next() else {
171-
// Empty circuit.
172-
return Ok(py_dict.into());
173-
};
174-
if first_duration.extract::<u64>().is_ok() {
175-
// All durations are of type u64
176-
let mut op_durations = HashMap::new();
177-
for (py_node, py_duration) in node_durations.iter() {
178-
let node_idx = py_node
179-
.cast_into::<DAGOpNode>()?
180-
.cast_into::<DAGNode>()?
181-
.borrow()
182-
.node
183-
.expect("Node index not found.");
184-
let val = py_duration.extract::<u64>()?;
185-
op_durations.insert(node_idx, val);
186-
}
187-
let node_start_time =
188-
run_asap_schedule_analysis::<u64>(dag, clbit_write_latency, op_durations)?;
189-
for (node_idx, t1) in node_start_time {
190-
let node = dag.get_node(py, node_idx)?;
191-
py_dict.set_item(node, t1)?;
192-
}
193-
} else if first_duration.extract::<f64>().is_ok() {
194-
// All durations are of type f64
195-
let mut op_durations = HashMap::new();
196-
for (py_node, py_duration) in node_durations.iter() {
197-
let node_idx = py_node
198-
.cast_into::<DAGOpNode>()?
199-
.cast_into::<DAGNode>()?
200-
.borrow()
201-
.node
202-
.expect("Node index not found.");
203-
let val = py_duration.extract::<f64>()?;
204-
op_durations.insert(node_idx, val);
167+
let new_durations: NodeDurations = match &*node_durations {
168+
NodeDurations::Dt(node_durations) => {
169+
run_asap_schedule_analysis(dag, clbit_write_latency, node_durations)?.into()
205170
}
206-
let node_start_time =
207-
run_asap_schedule_analysis::<f64>(dag, clbit_write_latency as f64, op_durations)?;
208-
for (node_idx, t1) in node_start_time {
209-
let node = dag.get_node(py, node_idx)?;
210-
py_dict.set_item(node, t1)?;
171+
NodeDurations::Seconds(node_durations) => {
172+
run_asap_schedule_analysis::<f64>(dag, clbit_write_latency as f64, node_durations)?
173+
.into()
211174
}
212-
} else {
213-
return Err(TranspilerError::new_err("Duration must be int or float"));
214-
}
215-
Ok(py_dict.into())
175+
};
176+
node_durations.update_durations(new_durations)?;
177+
Ok(node_durations)
216178
}
217179

218180
pub fn asap_schedule_analysis_mod(m: &Bound<PyModule>) -> PyResult<()> {

0 commit comments

Comments
 (0)