Skip to content

Commit aad2200

Browse files
committed
Merge remote-tracking branch 'ibm/main' into circuit-stretch
2 parents 62cf3ee + d2f4861 commit aad2200

107 files changed

Lines changed: 1159 additions & 11437 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

crates/accelerate/src/commutation_analysis.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ pub(crate) fn analyze_commutations_inner(
5454
py: Python,
5555
dag: &mut DAGCircuit,
5656
commutation_checker: &mut CommutationChecker,
57+
approximation_degree: f64,
5758
) -> PyResult<(CommutationSet, NodeIndices)> {
5859
let mut commutation_set: CommutationSet = HashMap::new();
5960
let mut node_indices: NodeIndices = HashMap::new();
@@ -102,6 +103,7 @@ pub(crate) fn analyze_commutations_inner(
102103
qargs2,
103104
cargs2,
104105
MAX_NUM_QUBITS,
106+
approximation_degree,
105107
)?;
106108
if !all_commute {
107109
break;
@@ -132,17 +134,19 @@ pub(crate) fn analyze_commutations_inner(
132134
}
133135

134136
#[pyfunction]
135-
#[pyo3(signature = (dag, commutation_checker))]
137+
#[pyo3(signature = (dag, commutation_checker, approximation_degree=1.))]
136138
pub(crate) fn analyze_commutations(
137139
py: Python,
138140
dag: &mut DAGCircuit,
139141
commutation_checker: &mut CommutationChecker,
142+
approximation_degree: f64,
140143
) -> PyResult<Py<PyDict>> {
141144
// This returns two HashMaps:
142145
// * The commuting nodes per wire: {wire: [commuting_nodes_1, commuting_nodes_2, ...]}
143146
// * The index in which commutation set a given node is located on a wire: {(node, wire): index}
144147
// The Python dict will store both of these dictionaries in one.
145-
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;
148+
let (commutation_set, node_indices) =
149+
analyze_commutations_inner(py, dag, commutation_checker, approximation_degree)?;
146150

147151
let out_dict = PyDict::new(py);
148152

crates/accelerate/src/commutation_cancellation.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,13 @@ struct CancellationSetKey {
5757
}
5858

5959
#[pyfunction]
60-
#[pyo3(signature = (dag, commutation_checker, basis_gates=None))]
60+
#[pyo3(signature = (dag, commutation_checker, basis_gates=None, approximation_degree=1.))]
6161
pub(crate) fn cancel_commutations(
6262
py: Python,
6363
dag: &mut DAGCircuit,
6464
commutation_checker: &mut CommutationChecker,
6565
basis_gates: Option<HashSet<String>>,
66+
approximation_degree: f64,
6667
) -> PyResult<()> {
6768
let basis: HashSet<String> = if let Some(basis) = basis_gates {
6869
basis
@@ -97,7 +98,8 @@ pub(crate) fn cancel_commutations(
9798
sec_commutation_set_id), the value is the list gates that share the same gate type,
9899
qubits and commutation sets.
99100
*/
100-
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;
101+
let (commutation_set, node_indices) =
102+
analyze_commutations_inner(py, dag, commutation_checker, approximation_degree)?;
101103
let mut cancellation_sets: HashMap<CancellationSetKey, Vec<NodeIndex>> = HashMap::new();
102104

103105
(0..dag.num_qubits() as u32).for_each(|qubit| {

crates/accelerate/src/commutation_checker.rs

Lines changed: 82 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use hashbrown::{HashMap, HashSet};
1414
use ndarray::linalg::kron;
1515
use ndarray::Array2;
1616
use num_complex::Complex64;
17+
use num_complex::ComplexFloat;
1718
use once_cell::sync::Lazy;
1819
use smallvec::SmallVec;
1920

@@ -34,6 +35,7 @@ use qiskit_circuit::operations::{
3435
};
3536
use qiskit_circuit::{BitType, Clbit, Qubit};
3637

38+
use crate::gate_metrics;
3739
use crate::unitary_compose;
3840
use crate::QiskitError;
3941

@@ -54,48 +56,30 @@ static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
5456
// and their pi-periodicity. Here we mean a gate is n-pi periodic, if for angles that are
5557
// multiples of n*pi, the gate is equal to the identity up to a global phase.
5658
// E.g. RX is generated by X and 2-pi periodic, while CRX is generated by CX and 4-pi periodic.
57-
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, (u8, Option<OperationRef>)>> = Lazy::new(|| {
59+
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, Option<OperationRef>>> = Lazy::new(|| {
5860
HashMap::from([
59-
(
60-
"rx",
61-
(2, Some(OperationRef::StandardGate(StandardGate::XGate))),
62-
),
63-
(
64-
"ry",
65-
(2, Some(OperationRef::StandardGate(StandardGate::YGate))),
66-
),
67-
(
68-
"rz",
69-
(2, Some(OperationRef::StandardGate(StandardGate::ZGate))),
70-
),
71-
(
72-
"p",
73-
(2, Some(OperationRef::StandardGate(StandardGate::ZGate))),
74-
),
75-
(
76-
"u1",
77-
(2, Some(OperationRef::StandardGate(StandardGate::ZGate))),
78-
),
79-
("rxx", (2, None)), // None means the gate is in the commutation dictionary
80-
("ryy", (2, None)),
81-
("rzx", (2, None)),
82-
("rzz", (2, None)),
61+
("rx", Some(OperationRef::StandardGate(StandardGate::XGate))),
62+
("ry", Some(OperationRef::StandardGate(StandardGate::YGate))),
63+
("rz", Some(OperationRef::StandardGate(StandardGate::ZGate))),
64+
("p", Some(OperationRef::StandardGate(StandardGate::ZGate))),
65+
("u1", Some(OperationRef::StandardGate(StandardGate::ZGate))),
66+
("rxx", None), // None means the gate is in the commutation dictionary
67+
("ryy", None),
68+
("rzx", None),
69+
("rzz", None),
8370
(
8471
"crx",
85-
(4, Some(OperationRef::StandardGate(StandardGate::CXGate))),
72+
Some(OperationRef::StandardGate(StandardGate::CXGate)),
8673
),
8774
(
8875
"cry",
89-
(4, Some(OperationRef::StandardGate(StandardGate::CYGate))),
76+
Some(OperationRef::StandardGate(StandardGate::CYGate)),
9077
),
9178
(
9279
"crz",
93-
(4, Some(OperationRef::StandardGate(StandardGate::CZGate))),
94-
),
95-
(
96-
"cp",
97-
(2, Some(OperationRef::StandardGate(StandardGate::CZGate))),
80+
Some(OperationRef::StandardGate(StandardGate::CZGate)),
9881
),
82+
("cp", Some(OperationRef::StandardGate(StandardGate::CZGate))),
9983
])
10084
});
10185

@@ -155,13 +139,14 @@ impl CommutationChecker {
155139
}
156140
}
157141

158-
#[pyo3(signature=(op1, op2, max_num_qubits=3))]
142+
#[pyo3(signature=(op1, op2, max_num_qubits=3, approximation_degree=1.))]
159143
fn commute_nodes(
160144
&mut self,
161145
py: Python,
162146
op1: &DAGOpNode,
163147
op2: &DAGOpNode,
164148
max_num_qubits: u32,
149+
approximation_degree: f64,
165150
) -> PyResult<bool> {
166151
let (qargs1, qargs2) = get_bits::<Qubit>(
167152
py,
@@ -185,10 +170,11 @@ impl CommutationChecker {
185170
&qargs2,
186171
&cargs2,
187172
max_num_qubits,
173+
approximation_degree,
188174
)
189175
}
190176

191-
#[pyo3(signature=(op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits=3))]
177+
#[pyo3(signature=(op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits=3, approximation_degree=1.))]
192178
#[allow(clippy::too_many_arguments)]
193179
fn commute(
194180
&mut self,
@@ -200,6 +186,7 @@ impl CommutationChecker {
200186
qargs2: Option<&Bound<PySequence>>,
201187
cargs2: Option<&Bound<PySequence>>,
202188
max_num_qubits: u32,
189+
approximation_degree: f64,
203190
) -> PyResult<bool> {
204191
let qargs1 = qargs1.map_or_else(|| Ok(PyTuple::empty(py)), PySequenceMethods::to_tuple)?;
205192
let cargs1 = cargs1.map_or_else(|| Ok(PyTuple::empty(py)), PySequenceMethods::to_tuple)?;
@@ -220,6 +207,7 @@ impl CommutationChecker {
220207
&qargs2,
221208
&cargs2,
222209
max_num_qubits,
210+
approximation_degree,
223211
)
224212
}
225213

@@ -288,20 +276,20 @@ impl CommutationChecker {
288276
qargs2: &[Qubit],
289277
cargs2: &[Clbit],
290278
max_num_qubits: u32,
279+
approximation_degree: f64,
291280
) -> PyResult<bool> {
292-
// relative and absolute tolerance used to (1) check whether rotation gates commute
293-
// trivially (i.e. the rotation angle is so small we assume it commutes) and (2) define
294-
// comparison for the matrix-based commutation checks
295-
let rtol = 1e-5;
296-
let atol = 1e-8;
281+
// If the average gate infidelity is below this tolerance, they commute. The tolerance
282+
// is set to max(1e-12, 1 - approximation_degree), to account for roundoffs and for
283+
// consistency with other places in Qiskit.
284+
let tol = 1e-12_f64.max(1. - approximation_degree);
297285

298286
// if we have rotation gates, we attempt to map them to their generators, for example
299287
// RX -> X or CPhase -> CZ
300-
let (op1, params1, trivial1) = map_rotation(op1, params1, rtol);
288+
let (op1, params1, trivial1) = map_rotation(op1, params1, tol);
301289
if trivial1 {
302290
return Ok(true);
303291
}
304-
let (op2, params2, trivial2) = map_rotation(op2, params2, rtol);
292+
let (op2, params2, trivial2) = map_rotation(op2, params2, tol);
305293
if trivial2 {
306294
return Ok(true);
307295
}
@@ -367,8 +355,7 @@ impl CommutationChecker {
367355
second_op,
368356
second_params,
369357
second_qargs,
370-
rtol,
371-
atol,
358+
tol,
372359
);
373360
}
374361

@@ -403,8 +390,7 @@ impl CommutationChecker {
403390
second_op,
404391
second_params,
405392
second_qargs,
406-
rtol,
407-
atol,
393+
tol,
408394
)?;
409395

410396
// TODO: implement a LRU cache for this
@@ -439,8 +425,7 @@ impl CommutationChecker {
439425
second_op: &OperationRef,
440426
second_params: &[Param],
441427
second_qargs: &[Qubit],
442-
rtol: f64,
443-
atol: f64,
428+
tol: f64,
444429
) -> PyResult<bool> {
445430
// Compute relative positioning of qargs of the second gate to the first gate.
446431
// Since the qargs come out the same BitData, we already know there are no accidential
@@ -481,81 +466,49 @@ impl CommutationChecker {
481466
None => return Ok(false),
482467
};
483468

484-
if first_qarg == second_qarg {
485-
match first_qarg.len() {
486-
1 => Ok(unitary_compose::commute_1q(
487-
&first_mat.view(),
488-
&second_mat.view(),
489-
rtol,
490-
atol,
491-
)),
492-
2 => Ok(unitary_compose::commute_2q(
493-
&first_mat.view(),
494-
&second_mat.view(),
495-
&[Qubit(0), Qubit(1)],
496-
rtol,
497-
atol,
498-
)),
499-
_ => Ok(unitary_compose::allclose(
500-
&second_mat.dot(&first_mat).view(),
501-
&first_mat.dot(&second_mat).view(),
502-
rtol,
503-
atol,
504-
)),
505-
}
469+
// TODO Optimize this bit to avoid unnecessary Kronecker products:
470+
// 1. We currently sort the operations for the cache by operation size, putting the
471+
// *smaller* operation first: (smaller op, larger op)
472+
// 2. This code here expands the first op to match the second -- hence we always
473+
// match the operator sizes.
474+
// This whole extension logic could be avoided since we know the second one is larger.
475+
let extra_qarg2 = num_qubits - first_qarg.len() as u32;
476+
let first_mat = if extra_qarg2 > 0 {
477+
let id_op = Array2::<Complex64>::eye(usize::pow(2, extra_qarg2));
478+
kron(&id_op, &first_mat)
506479
} else {
507-
// TODO Optimize this bit to avoid unnecessary Kronecker products:
508-
// 1. We currently sort the operations for the cache by operation size, putting the
509-
// *smaller* operation first: (smaller op, larger op)
510-
// 2. This code here expands the first op to match the second -- hence we always
511-
// match the operator sizes.
512-
// This whole extension logic could be avoided since we know the second one is larger.
513-
let extra_qarg2 = num_qubits - first_qarg.len() as u32;
514-
let first_mat = if extra_qarg2 > 0 {
515-
let id_op = Array2::<Complex64>::eye(usize::pow(2, extra_qarg2));
516-
kron(&id_op, &first_mat)
517-
} else {
518-
first_mat
519-
};
520-
521-
// the 1 qubit case cannot happen, since that would already have been captured
522-
// by the previous if clause; first_qarg == second_qarg (if they overlap they must
523-
// be the same)
524-
if num_qubits == 2 {
525-
return Ok(unitary_compose::commute_2q(
526-
&first_mat.view(),
527-
&second_mat.view(),
528-
&second_qarg,
529-
rtol,
530-
atol,
531-
));
532-
};
480+
first_mat
481+
};
533482

534-
let op12 = match unitary_compose::compose(
535-
&first_mat.view(),
536-
&second_mat.view(),
537-
&second_qarg,
538-
false,
539-
) {
540-
Ok(matrix) => matrix,
541-
Err(e) => return Err(PyRuntimeError::new_err(e)),
542-
};
543-
let op21 = match unitary_compose::compose(
544-
&first_mat.view(),
545-
&second_mat.view(),
546-
&second_qarg,
547-
true,
548-
) {
549-
Ok(matrix) => matrix,
550-
Err(e) => return Err(PyRuntimeError::new_err(e)),
551-
};
552-
Ok(unitary_compose::allclose(
553-
&op12.view(),
554-
&op21.view(),
555-
rtol,
556-
atol,
557-
))
558-
}
483+
// the 1 qubit case cannot happen, since that would already have been captured
484+
// by the previous if clause; first_qarg == second_qarg (if they overlap they must
485+
// be the same)
486+
let op12 = match unitary_compose::compose(
487+
&first_mat.view(),
488+
&second_mat.view(),
489+
&second_qarg,
490+
false,
491+
) {
492+
Ok(matrix) => matrix,
493+
Err(e) => return Err(PyRuntimeError::new_err(e)),
494+
};
495+
let op21 = match unitary_compose::compose(
496+
&first_mat.view(),
497+
&second_mat.view(),
498+
&second_qarg,
499+
true,
500+
) {
501+
Ok(matrix) => matrix,
502+
Err(e) => return Err(PyRuntimeError::new_err(e)),
503+
};
504+
let (fid, phase) = gate_metrics::gate_fidelity(&op12.view(), &op21.view(), None);
505+
506+
// we consider the gates as commuting if the process fidelity of
507+
// AB (BA)^\dagger is approximately the identity and there is no global phase difference
508+
// let dim = op12.ncols() as f64;
509+
// let matrix_tol = tol * dim.powi(2);
510+
let matrix_tol = tol;
511+
Ok(phase.abs() <= tol && (1.0 - fid).abs() <= matrix_tol)
559512
}
560513

561514
fn clear_cache(&mut self) {
@@ -652,13 +605,19 @@ fn map_rotation<'a>(
652605
) -> (&'a OperationRef<'a>, &'a [Param], bool) {
653606
let name = op.name();
654607

655-
if let Some((pi_multiple, generator)) = SUPPORTED_ROTATIONS.get(name) {
608+
if let Some(generator) = SUPPORTED_ROTATIONS.get(name) {
656609
// If the rotation angle is below the tolerance, the gate is assumed to
657610
// commute with everything, and we simply return the operation with the flag that
658611
// it commutes trivially.
659612
if let Param::Float(angle) = params[0] {
660-
let periodicity = (*pi_multiple as f64) * ::std::f64::consts::PI;
661-
if (angle % periodicity).abs() < tol {
613+
let gate = op
614+
.standard_gate()
615+
.expect("Supported gates are standard gates");
616+
let (tr_over_dim, dim) = gate_metrics::rotation_trace_and_dim(gate, angle)
617+
.expect("All rotation should be covered at this point");
618+
let gate_fidelity = tr_over_dim.abs().powi(2);
619+
let process_fidelity = (dim * gate_fidelity + 1.) / (dim + 1.);
620+
if (1. - process_fidelity).abs() <= tol {
662621
return (op, params, true);
663622
};
664623
};

0 commit comments

Comments
 (0)