|
| 1 | +// This code is part of Qiskit. |
| 2 | +// |
| 3 | +// (C) Copyright IBM 2025. |
| 4 | +// |
| 5 | +// This code is licensed under the Apache License, Version 2.0. You may |
| 6 | +// obtain a copy of this license in the LICENSE.txt file in the root directory |
| 7 | +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. |
| 8 | +// |
| 9 | +// Any modifications or derivative works of this code must retain this |
| 10 | +// copyright notice, and modified files need to carry a notice indicating |
| 11 | +// that they have been altered from the originals. |
| 12 | + |
| 13 | +use crate::TranspilerError; |
| 14 | +use crate::passes::schedule_analysis::{NodeDurations, PyNodeDurations}; |
| 15 | +use crate::target::Target; |
| 16 | +use ::hashbrown::HashSet; |
| 17 | +use ahash::RandomState; |
| 18 | +use indexmap::IndexMap; |
| 19 | +use pyo3::exceptions::PyValueError; |
| 20 | +use pyo3::prelude::*; |
| 21 | +use pyo3::{Bound, PyResult, pyfunction, wrap_pyfunction}; |
| 22 | +use qiskit_circuit::PhysicalQubit; |
| 23 | +use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType}; |
| 24 | +use qiskit_circuit::operations::Param; |
| 25 | +use qiskit_circuit::operations::{Operation, OperationRef, StandardInstruction}; |
| 26 | +use rustworkx_core::petgraph::stable_graph::NodeIndex; |
| 27 | + |
| 28 | +/// Returns the immediate successor operation nodes of a given node in the DAG. |
| 29 | +/// |
| 30 | +/// This function traverses the DAG to find all nodes that are direct successors |
| 31 | +/// of the given node and filters them to return only operation nodes. |
| 32 | +/// |
| 33 | +/// # Arguments |
| 34 | +/// |
| 35 | +/// * `dag` - Reference to the DAGCircuit containing the quantum circuit |
| 36 | +/// * `node_index` - Index of the node whose successors we want to find |
| 37 | +/// |
| 38 | +/// # Returns |
| 39 | +/// |
| 40 | +/// An iterator of `NodeIndex` values representing the immediate successor operation nodes. |
| 41 | +fn get_next_gate(dag: &DAGCircuit, node_index: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ { |
| 42 | + dag.quantum_successors(node_index) |
| 43 | + .filter(|&idx| matches!(dag[idx], NodeType::Operation(_))) |
| 44 | +} |
| 45 | + |
| 46 | +/// Update the start time of the current node to satisfy alignment constraints. |
| 47 | +/// Immediate successors are pushed back to avoid overlap and will be processed later. |
| 48 | +/// |
| 49 | +/// Note: |
| 50 | +/// This logic assumes that all bits in the qregs and cregs synchronously start and end, |
| 51 | +/// i.e. occupy the same time slot, but qregs and cregs can take different time slots |
| 52 | +/// due to classical I/O latencies. |
| 53 | +/// |
| 54 | +/// # Args: |
| 55 | +/// * `py` - Python interpreter reference for PyO3 operations |
| 56 | +/// * `dag` - Reference to the DAGCircuit to be rescheduled with constraints |
| 57 | +/// * `node_index` - Index of the current node to be processed |
| 58 | +/// * `node_start_time` - Mutable Python dictionary mapping node indices to start times |
| 59 | +/// * `clbit_write_latency` - Additional latency for classical bit write operations |
| 60 | +/// * `pulse_align` - Alignment constraint for gate operations (in dt units) |
| 61 | +/// * `acquire_align` - Alignment constraint for measurement/reset operations (in dt units) |
| 62 | +/// * `target` - Optional target backend for duration information |
| 63 | +fn push_node_back( |
| 64 | + dag: &DAGCircuit, |
| 65 | + node_index: NodeIndex, |
| 66 | + node_start_time: &mut IndexMap<NodeIndex<u32>, u64, RandomState>, |
| 67 | + clbit_write_latency: u32, |
| 68 | + pulse_align: u32, |
| 69 | + acquire_align: u32, |
| 70 | + target: Option<&Target>, |
| 71 | +) -> PyResult<()> { |
| 72 | + let NodeType::Operation(op) = &dag[node_index] else { |
| 73 | + unreachable!("topological_op_nodes() should only return operations.") |
| 74 | + }; |
| 75 | + |
| 76 | + let op_view = op.op.view(); |
| 77 | + let alignment = match op_view { |
| 78 | + OperationRef::Gate(_) | OperationRef::StandardGate(_) => Some(pulse_align), |
| 79 | + OperationRef::StandardInstruction(StandardInstruction::Reset) |
| 80 | + | OperationRef::StandardInstruction(StandardInstruction::Measure) => Some(acquire_align), |
| 81 | + OperationRef::StandardInstruction(StandardInstruction::Delay(_)) => None, |
| 82 | + _ => { |
| 83 | + if !op_view.directive() { |
| 84 | + None |
| 85 | + } else { |
| 86 | + return Err(TranspilerError::new_err(format!( |
| 87 | + "Unknown operation type for '{}'.", |
| 88 | + op_view.name() |
| 89 | + ))); |
| 90 | + } |
| 91 | + } |
| 92 | + }; |
| 93 | + |
| 94 | + let mut this_t0: u64 = *node_start_time |
| 95 | + .get(&node_index) |
| 96 | + .ok_or_else(|| PyValueError::new_err("Missing value in node_start_time"))?; |
| 97 | + |
| 98 | + if let Some(alignment) = alignment { |
| 99 | + let misalignment = this_t0 % alignment as u64; |
| 100 | + let shift = if misalignment != 0 { |
| 101 | + (alignment as u64).saturating_sub(misalignment) |
| 102 | + } else { |
| 103 | + 0 |
| 104 | + }; |
| 105 | + this_t0 += shift; |
| 106 | + node_start_time |
| 107 | + .entry(node_index) |
| 108 | + .and_modify(|old_t0| *old_t0 = this_t0) |
| 109 | + .or_insert(this_t0); |
| 110 | + } |
| 111 | + |
| 112 | + let new_t1q = if let Some(target) = target { |
| 113 | + let qargs: Vec<PhysicalQubit> = dag |
| 114 | + .qargs_interner() |
| 115 | + .get(op.qubits) |
| 116 | + .iter() |
| 117 | + .map(|q| PhysicalQubit(q.index() as u32)) |
| 118 | + .collect(); |
| 119 | + let duration = target.get_duration(op.op.name(), &qargs).unwrap_or(0.0); |
| 120 | + this_t0 + duration as u64 |
| 121 | + } else if matches!( |
| 122 | + op_view, |
| 123 | + OperationRef::StandardInstruction(StandardInstruction::Delay(_)) |
| 124 | + ) { |
| 125 | + let params = op.params_view(); |
| 126 | + let param = params |
| 127 | + .first() |
| 128 | + .ok_or_else(|| PyValueError::new_err("Delay instruction missing duration parameter"))?; |
| 129 | + let duration = match param { |
| 130 | + Param::Obj(val) => { |
| 131 | + // Try to extract as different numeric types |
| 132 | + Python::attach(|py| val.bind(py).extract::<u64>()) |
| 133 | + } |
| 134 | + Param::Float(f) => Ok(*f as u64), |
| 135 | + _ => Err(TranspilerError::new_err( |
| 136 | + "The provided Delay duration is not in terms of dt.", |
| 137 | + )), |
| 138 | + }?; |
| 139 | + |
| 140 | + this_t0 + duration |
| 141 | + } else { |
| 142 | + this_t0 |
| 143 | + }; |
| 144 | + |
| 145 | + let this_qubits: HashSet<_> = dag |
| 146 | + .qargs_interner() |
| 147 | + .get(op.qubits) |
| 148 | + .iter() |
| 149 | + .map(|q| q.index()) |
| 150 | + .collect(); |
| 151 | + |
| 152 | + // Handle classical bits based on operation type |
| 153 | + let (new_t1c, this_clbits) = if matches!( |
| 154 | + op_view, |
| 155 | + OperationRef::StandardInstruction(StandardInstruction::Measure) |
| 156 | + | OperationRef::StandardInstruction(StandardInstruction::Reset) |
| 157 | + ) { |
| 158 | + // creg access ends at the end of instruction |
| 159 | + let new_t1c = Some(new_t1q); |
| 160 | + let this_clbits: HashSet<_> = dag |
| 161 | + .cargs_interner() |
| 162 | + .get(op.clbits) |
| 163 | + .iter() |
| 164 | + .map(|c| c.index()) |
| 165 | + .collect(); |
| 166 | + (new_t1c, this_clbits) |
| 167 | + } else { |
| 168 | + (None, HashSet::new()) |
| 169 | + }; |
| 170 | + // Check immediate successors for overlap |
| 171 | + for next_node_index in get_next_gate(dag, node_index) { |
| 172 | + // Get the next node |
| 173 | + let NodeType::Operation(next_node) = &dag[next_node_index] else { |
| 174 | + unreachable!("topological_op_nodes() should only return operations.") |
| 175 | + }; |
| 176 | + |
| 177 | + // Compute next node start time separately for qreg and creg |
| 178 | + let next_t0q: u64 = node_start_time |
| 179 | + .get(&next_node_index) |
| 180 | + .copied() |
| 181 | + .expect("Expected value in node_start_time for next_node_index"); |
| 182 | + |
| 183 | + let next_qubits: HashSet<_> = dag |
| 184 | + .qargs_interner() |
| 185 | + .get(next_node.qubits) |
| 186 | + .iter() |
| 187 | + .map(|q| q.index()) |
| 188 | + .collect(); |
| 189 | + |
| 190 | + let next_op_view = next_node.op.view(); |
| 191 | + let (next_t0c, next_clbits) = if matches!( |
| 192 | + next_op_view, |
| 193 | + OperationRef::StandardInstruction(StandardInstruction::Measure) |
| 194 | + | OperationRef::StandardInstruction(StandardInstruction::Reset) |
| 195 | + ) { |
| 196 | + // creg access starts after write latency |
| 197 | + let next_t0c = Some(next_t0q + clbit_write_latency as u64); |
| 198 | + let next_clbits: HashSet<_> = dag |
| 199 | + .cargs_interner() |
| 200 | + .get(next_node.clbits) |
| 201 | + .iter() |
| 202 | + .map(|c| c.index()) |
| 203 | + .collect(); |
| 204 | + (next_t0c, next_clbits) |
| 205 | + } else { |
| 206 | + (None, HashSet::new()) |
| 207 | + }; |
| 208 | + |
| 209 | + // Compute overlap if there is qubits overlap |
| 210 | + let qreg_overlap = if !this_qubits.is_disjoint(&next_qubits) { |
| 211 | + new_t1q - next_t0q |
| 212 | + } else { |
| 213 | + 0 |
| 214 | + }; |
| 215 | + |
| 216 | + // Compute overlap if there is clbits overlap |
| 217 | + let creg_overlap = if !this_clbits.is_empty() |
| 218 | + && !next_clbits.is_empty() |
| 219 | + && !this_clbits.is_disjoint(&next_clbits) |
| 220 | + { |
| 221 | + if let (Some(t1c), Some(t0c)) = (new_t1c, next_t0c) { |
| 222 | + t1c - t0c |
| 223 | + } else { |
| 224 | + 0 |
| 225 | + } |
| 226 | + } else { |
| 227 | + 0 |
| 228 | + }; |
| 229 | + |
| 230 | + // Shift next node if there is finite overlap in either qubits or clbits |
| 231 | + let overlap = qreg_overlap.max(creg_overlap); |
| 232 | + if overlap > 0 { |
| 233 | + let new_start_time = next_t0q + overlap; |
| 234 | + node_start_time |
| 235 | + .entry(next_node_index) |
| 236 | + .and_modify(|old_start| *old_start = new_start_time) |
| 237 | + .or_insert(new_start_time); |
| 238 | + } |
| 239 | + } |
| 240 | + Ok(()) |
| 241 | +} |
| 242 | + |
| 243 | +#[pyfunction] |
| 244 | +#[pyo3(name="constrained_reschedule", signature=(dag, node_start_time, clbit_write_latency, acquire_align, pulse_align, target))] |
| 245 | +pub fn py_run_constrained_reschedule( |
| 246 | + dag: &DAGCircuit, |
| 247 | + mut node_start_time: PyNodeDurations, |
| 248 | + clbit_write_latency: u32, |
| 249 | + acquire_align: u32, |
| 250 | + pulse_align: u32, |
| 251 | + target: Option<&Target>, |
| 252 | +) -> PyResult<PyNodeDurations> { |
| 253 | + let NodeDurations::Dt(durations) = &mut *node_start_time else { |
| 254 | + return Err(TranspilerError::new_err( |
| 255 | + "The durations provided have not been properly converted to 'dt' units.", |
| 256 | + )); |
| 257 | + }; |
| 258 | + run_constrained_reschedule( |
| 259 | + dag, |
| 260 | + durations, |
| 261 | + clbit_write_latency, |
| 262 | + acquire_align, |
| 263 | + pulse_align, |
| 264 | + target, |
| 265 | + )?; |
| 266 | + Ok(node_start_time) |
| 267 | +} |
| 268 | + |
| 269 | +pub fn run_constrained_reschedule( |
| 270 | + dag: &DAGCircuit, |
| 271 | + node_start_time: &mut IndexMap<NodeIndex<u32>, u64, RandomState>, |
| 272 | + clbit_write_latency: u32, |
| 273 | + acquire_align: u32, |
| 274 | + pulse_align: u32, |
| 275 | + target: Option<&Target>, |
| 276 | +) -> PyResult<()> { |
| 277 | + for node_index in dag.topological_op_nodes(false) { |
| 278 | + let start_time = node_start_time.get(&node_index); |
| 279 | + let val = *start_time.ok_or_else(|| { |
| 280 | + TranspilerError::new_err(format!( |
| 281 | + "Missing start time for node {}. Run scheduler again.", |
| 282 | + node_index.index() |
| 283 | + )) |
| 284 | + })?; |
| 285 | + |
| 286 | + if val == 0 { |
| 287 | + continue; |
| 288 | + } |
| 289 | + |
| 290 | + push_node_back( |
| 291 | + dag, |
| 292 | + node_index, |
| 293 | + node_start_time, |
| 294 | + clbit_write_latency, |
| 295 | + acquire_align, |
| 296 | + pulse_align, |
| 297 | + target, |
| 298 | + )?; |
| 299 | + } |
| 300 | + |
| 301 | + Ok(()) |
| 302 | +} |
| 303 | + |
| 304 | +pub fn constrained_reschedule_mod(m: &Bound<PyModule>) -> PyResult<()> { |
| 305 | + m.add_wrapped(wrap_pyfunction!(py_run_constrained_reschedule))?; |
| 306 | + Ok(()) |
| 307 | +} |
0 commit comments