@@ -14,6 +14,7 @@ use hashbrown::{HashMap, HashSet};
1414use ndarray:: linalg:: kron;
1515use ndarray:: Array2 ;
1616use num_complex:: Complex64 ;
17+ use num_complex:: ComplexFloat ;
1718use once_cell:: sync:: Lazy ;
1819use smallvec:: SmallVec ;
1920
@@ -34,6 +35,7 @@ use qiskit_circuit::operations::{
3435} ;
3536use qiskit_circuit:: { BitType , Clbit , Qubit } ;
3637
38+ use crate :: gate_metrics;
3739use crate :: unitary_compose;
3840use crate :: QiskitError ;
3941
@@ -54,48 +56,30 @@ static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
5456// and their pi-periodicity. Here we mean a gate is n-pi periodic, if for angles that are
5557// multiples of n*pi, the gate is equal to the identity up to a global phase.
5658// E.g. RX is generated by X and 2-pi periodic, while CRX is generated by CX and 4-pi periodic.
57- static SUPPORTED_ROTATIONS : Lazy < HashMap < & str , ( u8 , Option < OperationRef > ) > > = Lazy :: new ( || {
59+ static SUPPORTED_ROTATIONS : Lazy < HashMap < & str , Option < OperationRef > > > = Lazy :: new ( || {
5860 HashMap :: from ( [
59- (
60- "rx" ,
61- ( 2 , Some ( OperationRef :: StandardGate ( StandardGate :: XGate ) ) ) ,
62- ) ,
63- (
64- "ry" ,
65- ( 2 , Some ( OperationRef :: StandardGate ( StandardGate :: YGate ) ) ) ,
66- ) ,
67- (
68- "rz" ,
69- ( 2 , Some ( OperationRef :: StandardGate ( StandardGate :: ZGate ) ) ) ,
70- ) ,
71- (
72- "p" ,
73- ( 2 , Some ( OperationRef :: StandardGate ( StandardGate :: ZGate ) ) ) ,
74- ) ,
75- (
76- "u1" ,
77- ( 2 , Some ( OperationRef :: StandardGate ( StandardGate :: ZGate ) ) ) ,
78- ) ,
79- ( "rxx" , ( 2 , None ) ) , // None means the gate is in the commutation dictionary
80- ( "ryy" , ( 2 , None ) ) ,
81- ( "rzx" , ( 2 , None ) ) ,
82- ( "rzz" , ( 2 , None ) ) ,
61+ ( "rx" , Some ( OperationRef :: StandardGate ( StandardGate :: XGate ) ) ) ,
62+ ( "ry" , Some ( OperationRef :: StandardGate ( StandardGate :: YGate ) ) ) ,
63+ ( "rz" , Some ( OperationRef :: StandardGate ( StandardGate :: ZGate ) ) ) ,
64+ ( "p" , Some ( OperationRef :: StandardGate ( StandardGate :: ZGate ) ) ) ,
65+ ( "u1" , Some ( OperationRef :: StandardGate ( StandardGate :: ZGate ) ) ) ,
66+ ( "rxx" , None ) , // None means the gate is in the commutation dictionary
67+ ( "ryy" , None ) ,
68+ ( "rzx" , None ) ,
69+ ( "rzz" , None ) ,
8370 (
8471 "crx" ,
85- ( 4 , Some ( OperationRef :: StandardGate ( StandardGate :: CXGate ) ) ) ,
72+ Some ( OperationRef :: StandardGate ( StandardGate :: CXGate ) ) ,
8673 ) ,
8774 (
8875 "cry" ,
89- ( 4 , Some ( OperationRef :: StandardGate ( StandardGate :: CYGate ) ) ) ,
76+ Some ( OperationRef :: StandardGate ( StandardGate :: CYGate ) ) ,
9077 ) ,
9178 (
9279 "crz" ,
93- ( 4 , Some ( OperationRef :: StandardGate ( StandardGate :: CZGate ) ) ) ,
94- ) ,
95- (
96- "cp" ,
97- ( 2 , Some ( OperationRef :: StandardGate ( StandardGate :: CZGate ) ) ) ,
80+ Some ( OperationRef :: StandardGate ( StandardGate :: CZGate ) ) ,
9881 ) ,
82+ ( "cp" , Some ( OperationRef :: StandardGate ( StandardGate :: CZGate ) ) ) ,
9983 ] )
10084} ) ;
10185
@@ -155,13 +139,14 @@ impl CommutationChecker {
155139 }
156140 }
157141
158- #[ pyo3( signature=( op1, op2, max_num_qubits=3 ) ) ]
142+ #[ pyo3( signature=( op1, op2, max_num_qubits=3 , approximation_degree= 1. ) ) ]
159143 fn commute_nodes (
160144 & mut self ,
161145 py : Python ,
162146 op1 : & DAGOpNode ,
163147 op2 : & DAGOpNode ,
164148 max_num_qubits : u32 ,
149+ approximation_degree : f64 ,
165150 ) -> PyResult < bool > {
166151 let ( qargs1, qargs2) = get_bits :: < Qubit > (
167152 py,
@@ -185,10 +170,11 @@ impl CommutationChecker {
185170 & qargs2,
186171 & cargs2,
187172 max_num_qubits,
173+ approximation_degree,
188174 )
189175 }
190176
191- #[ pyo3( signature=( op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits=3 ) ) ]
177+ #[ pyo3( signature=( op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits=3 , approximation_degree= 1. ) ) ]
192178 #[ allow( clippy:: too_many_arguments) ]
193179 fn commute (
194180 & mut self ,
@@ -200,6 +186,7 @@ impl CommutationChecker {
200186 qargs2 : Option < & Bound < PySequence > > ,
201187 cargs2 : Option < & Bound < PySequence > > ,
202188 max_num_qubits : u32 ,
189+ approximation_degree : f64 ,
203190 ) -> PyResult < bool > {
204191 let qargs1 = qargs1. map_or_else ( || Ok ( PyTuple :: empty ( py) ) , PySequenceMethods :: to_tuple) ?;
205192 let cargs1 = cargs1. map_or_else ( || Ok ( PyTuple :: empty ( py) ) , PySequenceMethods :: to_tuple) ?;
@@ -220,6 +207,7 @@ impl CommutationChecker {
220207 & qargs2,
221208 & cargs2,
222209 max_num_qubits,
210+ approximation_degree,
223211 )
224212 }
225213
@@ -288,20 +276,20 @@ impl CommutationChecker {
288276 qargs2 : & [ Qubit ] ,
289277 cargs2 : & [ Clbit ] ,
290278 max_num_qubits : u32 ,
279+ approximation_degree : f64 ,
291280 ) -> PyResult < bool > {
292- // relative and absolute tolerance used to (1) check whether rotation gates commute
293- // trivially (i.e. the rotation angle is so small we assume it commutes) and (2) define
294- // comparison for the matrix-based commutation checks
295- let rtol = 1e-5 ;
296- let atol = 1e-8 ;
281+ // If the average gate infidelity is below this tolerance, they commute. The tolerance
282+ // is set to max(1e-12, 1 - approximation_degree), to account for roundoffs and for
283+ // consistency with other places in Qiskit.
284+ let tol = 1e-12_f64 . max ( 1. - approximation_degree) ;
297285
298286 // if we have rotation gates, we attempt to map them to their generators, for example
299287 // RX -> X or CPhase -> CZ
300- let ( op1, params1, trivial1) = map_rotation ( op1, params1, rtol ) ;
288+ let ( op1, params1, trivial1) = map_rotation ( op1, params1, tol ) ;
301289 if trivial1 {
302290 return Ok ( true ) ;
303291 }
304- let ( op2, params2, trivial2) = map_rotation ( op2, params2, rtol ) ;
292+ let ( op2, params2, trivial2) = map_rotation ( op2, params2, tol ) ;
305293 if trivial2 {
306294 return Ok ( true ) ;
307295 }
@@ -367,8 +355,7 @@ impl CommutationChecker {
367355 second_op,
368356 second_params,
369357 second_qargs,
370- rtol,
371- atol,
358+ tol,
372359 ) ;
373360 }
374361
@@ -403,8 +390,7 @@ impl CommutationChecker {
403390 second_op,
404391 second_params,
405392 second_qargs,
406- rtol,
407- atol,
393+ tol,
408394 ) ?;
409395
410396 // TODO: implement a LRU cache for this
@@ -439,8 +425,7 @@ impl CommutationChecker {
439425 second_op : & OperationRef ,
440426 second_params : & [ Param ] ,
441427 second_qargs : & [ Qubit ] ,
442- rtol : f64 ,
443- atol : f64 ,
428+ tol : f64 ,
444429 ) -> PyResult < bool > {
445430 // Compute relative positioning of qargs of the second gate to the first gate.
446431 // Since the qargs come out the same BitData, we already know there are no accidential
@@ -481,81 +466,49 @@ impl CommutationChecker {
481466 None => return Ok ( false ) ,
482467 } ;
483468
484- if first_qarg == second_qarg {
485- match first_qarg. len ( ) {
486- 1 => Ok ( unitary_compose:: commute_1q (
487- & first_mat. view ( ) ,
488- & second_mat. view ( ) ,
489- rtol,
490- atol,
491- ) ) ,
492- 2 => Ok ( unitary_compose:: commute_2q (
493- & first_mat. view ( ) ,
494- & second_mat. view ( ) ,
495- & [ Qubit ( 0 ) , Qubit ( 1 ) ] ,
496- rtol,
497- atol,
498- ) ) ,
499- _ => Ok ( unitary_compose:: allclose (
500- & second_mat. dot ( & first_mat) . view ( ) ,
501- & first_mat. dot ( & second_mat) . view ( ) ,
502- rtol,
503- atol,
504- ) ) ,
505- }
469+ // TODO Optimize this bit to avoid unnecessary Kronecker products:
470+ // 1. We currently sort the operations for the cache by operation size, putting the
471+ // *smaller* operation first: (smaller op, larger op)
472+ // 2. This code here expands the first op to match the second -- hence we always
473+ // match the operator sizes.
474+ // This whole extension logic could be avoided since we know the second one is larger.
475+ let extra_qarg2 = num_qubits - first_qarg. len ( ) as u32 ;
476+ let first_mat = if extra_qarg2 > 0 {
477+ let id_op = Array2 :: < Complex64 > :: eye ( usize:: pow ( 2 , extra_qarg2) ) ;
478+ kron ( & id_op, & first_mat)
506479 } else {
507- // TODO Optimize this bit to avoid unnecessary Kronecker products:
508- // 1. We currently sort the operations for the cache by operation size, putting the
509- // *smaller* operation first: (smaller op, larger op)
510- // 2. This code here expands the first op to match the second -- hence we always
511- // match the operator sizes.
512- // This whole extension logic could be avoided since we know the second one is larger.
513- let extra_qarg2 = num_qubits - first_qarg. len ( ) as u32 ;
514- let first_mat = if extra_qarg2 > 0 {
515- let id_op = Array2 :: < Complex64 > :: eye ( usize:: pow ( 2 , extra_qarg2) ) ;
516- kron ( & id_op, & first_mat)
517- } else {
518- first_mat
519- } ;
520-
521- // the 1 qubit case cannot happen, since that would already have been captured
522- // by the previous if clause; first_qarg == second_qarg (if they overlap they must
523- // be the same)
524- if num_qubits == 2 {
525- return Ok ( unitary_compose:: commute_2q (
526- & first_mat. view ( ) ,
527- & second_mat. view ( ) ,
528- & second_qarg,
529- rtol,
530- atol,
531- ) ) ;
532- } ;
480+ first_mat
481+ } ;
533482
534- let op12 = match unitary_compose:: compose (
535- & first_mat. view ( ) ,
536- & second_mat. view ( ) ,
537- & second_qarg,
538- false ,
539- ) {
540- Ok ( matrix) => matrix,
541- Err ( e) => return Err ( PyRuntimeError :: new_err ( e) ) ,
542- } ;
543- let op21 = match unitary_compose:: compose (
544- & first_mat. view ( ) ,
545- & second_mat. view ( ) ,
546- & second_qarg,
547- true ,
548- ) {
549- Ok ( matrix) => matrix,
550- Err ( e) => return Err ( PyRuntimeError :: new_err ( e) ) ,
551- } ;
552- Ok ( unitary_compose:: allclose (
553- & op12. view ( ) ,
554- & op21. view ( ) ,
555- rtol,
556- atol,
557- ) )
558- }
483+ // the 1 qubit case cannot happen, since that would already have been captured
484+ // by the previous if clause; first_qarg == second_qarg (if they overlap they must
485+ // be the same)
486+ let op12 = match unitary_compose:: compose (
487+ & first_mat. view ( ) ,
488+ & second_mat. view ( ) ,
489+ & second_qarg,
490+ false ,
491+ ) {
492+ Ok ( matrix) => matrix,
493+ Err ( e) => return Err ( PyRuntimeError :: new_err ( e) ) ,
494+ } ;
495+ let op21 = match unitary_compose:: compose (
496+ & first_mat. view ( ) ,
497+ & second_mat. view ( ) ,
498+ & second_qarg,
499+ true ,
500+ ) {
501+ Ok ( matrix) => matrix,
502+ Err ( e) => return Err ( PyRuntimeError :: new_err ( e) ) ,
503+ } ;
504+ let ( fid, phase) = gate_metrics:: gate_fidelity ( & op12. view ( ) , & op21. view ( ) , None ) ;
505+
506+ // we consider the gates as commuting if the process fidelity of
507+ // AB (BA)^\dagger is approximately the identity and there is no global phase difference
508+ // let dim = op12.ncols() as f64;
509+ // let matrix_tol = tol * dim.powi(2);
510+ let matrix_tol = tol;
511+ Ok ( phase. abs ( ) <= tol && ( 1.0 - fid) . abs ( ) <= matrix_tol)
559512 }
560513
561514 fn clear_cache ( & mut self ) {
@@ -652,13 +605,19 @@ fn map_rotation<'a>(
652605) -> ( & ' a OperationRef < ' a > , & ' a [ Param ] , bool ) {
653606 let name = op. name ( ) ;
654607
655- if let Some ( ( pi_multiple , generator) ) = SUPPORTED_ROTATIONS . get ( name) {
608+ if let Some ( generator) = SUPPORTED_ROTATIONS . get ( name) {
656609 // If the rotation angle is below the tolerance, the gate is assumed to
657610 // commute with everything, and we simply return the operation with the flag that
658611 // it commutes trivially.
659612 if let Param :: Float ( angle) = params[ 0 ] {
660- let periodicity = ( * pi_multiple as f64 ) * :: std:: f64:: consts:: PI ;
661- if ( angle % periodicity) . abs ( ) < tol {
613+ let gate = op
614+ . standard_gate ( )
615+ . expect ( "Supported gates are standard gates" ) ;
616+ let ( tr_over_dim, dim) = gate_metrics:: rotation_trace_and_dim ( gate, angle)
617+ . expect ( "All rotation should be covered at this point" ) ;
618+ let gate_fidelity = tr_over_dim. abs ( ) . powi ( 2 ) ;
619+ let process_fidelity = ( dim * gate_fidelity + 1. ) / ( dim + 1. ) ;
620+ if ( 1. - process_fidelity) . abs ( ) <= tol {
662621 return ( op, params, true ) ;
663622 } ;
664623 } ;
0 commit comments