-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Improve commutation checking of Pauli product rotations and measurements #15815
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
5d2d7d0
dff5668
49960e1
0dde194
4cd1ff4
789e1c6
de13108
087362f
08f9fa5
1beefa4
e45c48a
6cbce16
b654130
d889df0
06cbb55
f4a9df6
b13a58d
76a8c74
55e93d1
c90835c
b3e6456
a685517
9f7a432
fb03d53
c2c3ca5
a669772
a67a060
d13a51a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -240,7 +240,8 @@ fn try_extract_op_from_ppr( | |
| Some(out.compose_map(&local, |i| qubits[i as usize].0)) | ||
| } | ||
|
|
||
| fn try_pauli_generator( | ||
| /// Attempt to extract generator of a Pauli-based gate in the form of a sparse observable. | ||
| fn try_sparse_observable_generator_for_pauli_based( | ||
| operation: &OperationRef, | ||
| qubits: &[Qubit], | ||
| num_qubits: u32, | ||
|
|
@@ -254,7 +255,8 @@ fn try_pauli_generator( | |
| } | ||
| } | ||
|
|
||
| fn try_standard_gate_generator( | ||
| /// Attemp to extract generator of a standard gate in the form of a sparse observable. | ||
| fn try_sparse_observable_generator_for_standard_gate( | ||
|
alexanderivrii marked this conversation as resolved.
|
||
| operation: &OperationRef, | ||
| params: &[Param], | ||
| qubits: &[Qubit], | ||
|
|
@@ -269,6 +271,18 @@ fn try_standard_gate_generator( | |
| None | ||
| } | ||
|
|
||
| /// Attempt to extract generator of a Pauli-based gate in the form of a single Pauli. | ||
| /// When successful, return the generator in the (Z, X) form. | ||
| pub fn try_pauli_generator_for_pauli_based<'a>( | ||
| operation: &'a OperationRef, | ||
| ) -> Option<(&'a Vec<bool>, &'a Vec<bool>)> { | ||
|
jan-an marked this conversation as resolved.
|
||
| match operation { | ||
| OperationRef::PauliProductRotation(ppr) => Some((&ppr.z, &ppr.x)), | ||
| OperationRef::PauliProductMeasurement(ppm) => Some((&ppm.z, &ppm.x)), | ||
| _ => None, | ||
| } | ||
| } | ||
|
|
||
| fn get_bits_from_py<T>( | ||
| py_bits1: &Bound<'_, PyTuple>, | ||
| py_bits2: &Bound<'_, PyTuple>, | ||
|
|
@@ -307,6 +321,7 @@ pub struct CommutationChecker { | |
| current_cache_entries: usize, | ||
| #[pyo3(get)] | ||
| gates: Option<HashSet<String>>, | ||
| scratch_map: HashMap<usize, usize>, | ||
|
Cryoris marked this conversation as resolved.
|
||
| } | ||
|
|
||
| #[pymethods] | ||
|
|
@@ -460,6 +475,7 @@ impl CommutationChecker { | |
| cache_max_entries, | ||
| current_cache_entries: 0, | ||
| gates, | ||
| scratch_map: HashMap::new(), | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -542,21 +558,53 @@ impl CommutationChecker { | |
| _ => (), | ||
| }; | ||
|
|
||
| // Sort the arguments, such that `op2` always is the larger one. | ||
| let reversed = (op1.num_qubits(), op1.name().len(), op1.name()) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any specific reason we are using the operation name lengths/ name while sorting?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Honestly, I don't know, this code existed before, I only made it a bit more concise. |
||
| > (op2.num_qubits(), op2.name().len(), op2.name()); | ||
|
|
||
| let (op1, op2, params1, params2, qargs1, qargs2) = if reversed { | ||
| (op2, op1, params2, params1, qargs2, qargs1) | ||
| } else { | ||
| (op1, op2, params1, params2, qargs1, qargs2) | ||
| }; | ||
|
|
||
| // Handle commutations using Pauli-based generators. | ||
| if let (Some((z1, x1)), Some((z2, x2))) = ( | ||
| try_pauli_generator_for_pauli_based(op1), | ||
| try_pauli_generator_for_pauli_based(op2), | ||
| ) { | ||
| self.scratch_map.clear(); | ||
| for (i, &q) in qargs1.iter().enumerate() { | ||
| self.scratch_map.insert(q.index(), i); | ||
| } | ||
|
Cryoris marked this conversation as resolved.
|
||
| let mut parity = false; | ||
| for (j, &q) in qargs2.iter().enumerate() { | ||
| if let Some(&i) = self.scratch_map.get(&q.index()) { | ||
| parity ^= (x1[i] && z2[j]) ^ (z1[i] && x2[j]); | ||
| } | ||
| } | ||
| return Ok(!parity); | ||
| } | ||
|
|
||
| // Handle commutations between Pauli-based gates among themselves, and with standard gates | ||
| // TODO Support trivial commutations of standard gates with identities in the Paulis | ||
| let size = qargs1.iter().chain(qargs2.iter()).max().unwrap().0 + 1; | ||
| let maybe_pauli1 = try_pauli_generator(op1, qargs1, size); | ||
| let maybe_pauli2 = try_pauli_generator(op2, qargs2, size); | ||
| let maybe_pauli1 = try_sparse_observable_generator_for_pauli_based(op1, qargs1, size); | ||
| let maybe_pauli2 = try_sparse_observable_generator_for_pauli_based(op2, qargs2, size); | ||
|
|
||
| match (maybe_pauli1, maybe_pauli2) { | ||
| (None, None) => (), // No gate is Pauli-based, continue | ||
| (None, Some(pauli2)) => { | ||
| if let Some(pauli1) = try_standard_gate_generator(op1, params1, qargs1, size) { | ||
| if let Some(pauli1) = | ||
| try_sparse_observable_generator_for_standard_gate(op1, params1, qargs1, size) | ||
| { | ||
| return Ok(pauli1.commutes(&pauli2, tol)); | ||
| } | ||
| } | ||
| (Some(pauli1), None) => { | ||
| if let Some(pauli2) = try_standard_gate_generator(op2, params2, qargs2, size) { | ||
| if let Some(pauli2) = | ||
| try_sparse_observable_generator_for_standard_gate(op2, params2, qargs2, size) | ||
| { | ||
| return Ok(pauli1.commutes(&pauli2, tol)); | ||
| } | ||
| } | ||
|
|
@@ -568,24 +616,6 @@ impl CommutationChecker { | |
| return Ok(false); | ||
| } | ||
|
|
||
| // Sort the arguments, such that `second_op` always is the larger one. | ||
| let reversed = if op1.num_qubits() != op2.num_qubits() { | ||
| op1.num_qubits() > op2.num_qubits() | ||
| } else { | ||
| (op1.name().len(), op1.name()) >= (op2.name().len(), op2.name()) | ||
| }; | ||
| let (first_params, second_params) = if reversed { | ||
| (params2, params1) | ||
| } else { | ||
| (params1, params2) | ||
| }; | ||
| let (first_op, second_op) = if reversed { (op2, op1) } else { (op1, op2) }; | ||
| let (first_qargs, second_qargs) = if reversed { | ||
| (qargs2, qargs1) | ||
| } else { | ||
| (qargs1, qargs2) | ||
| }; | ||
|
|
||
| // For our cache to work correctly, we require the gate's definition to only depend on the | ||
| // ``params`` attribute. This cannot be guaranteed for custom gates, so we only check | ||
| // the cache for | ||
|
|
@@ -599,70 +629,53 @@ impl CommutationChecker { | |
| false | ||
| } | ||
| }; | ||
| let check_cache = | ||
| is_cachable(first_op, first_params) && is_cachable(second_op, second_params); | ||
| let check_cache = is_cachable(op1, params1) && is_cachable(op2, params2); | ||
|
|
||
| if !check_cache { | ||
| // The arguments are sorted, so if first_qargs.len() > matrix_max_num_qubits, then | ||
| // second_qargs.len() > matrix_max_num_qubits as well. | ||
| if second_qargs.len() > matrix_max_num_qubits as usize { | ||
| // The arguments are sorted, so if qargs1.len() > matrix_max_num_qubits, then | ||
| // qargs2.len() > matrix_max_num_qubits as well. | ||
| if qargs2.len() > matrix_max_num_qubits as usize { | ||
| return Ok(false); | ||
| } | ||
| return self.commute_matmul( | ||
| first_op, | ||
| first_params, | ||
| first_qargs, | ||
| second_op, | ||
| second_params, | ||
| second_qargs, | ||
| tol, | ||
| ); | ||
| return self.commute_matmul(op1, params1, qargs1, op2, params2, qargs2, tol); | ||
| } | ||
|
|
||
| // Query commutation library | ||
| let relative_placement = get_relative_placement(first_qargs, second_qargs); | ||
| let relative_placement = get_relative_placement(qargs1, qargs2); | ||
| if let Some(is_commuting) = | ||
| self.library | ||
| .check_commutation_entries(first_op, second_op, &relative_placement) | ||
| .check_commutation_entries(op1, op2, &relative_placement) | ||
| { | ||
| return Ok(is_commuting); | ||
| } | ||
|
|
||
| // Query cache | ||
| let key1 = hashable_params(first_params)?; | ||
| let key2 = hashable_params(second_params)?; | ||
| let key1 = hashable_params(params1)?; | ||
| let key2 = hashable_params(params2)?; | ||
| if let Some(commutation_dict) = self | ||
| .cache | ||
| .get(&(first_op.name().to_string(), second_op.name().to_string())) | ||
| .get(&(op1.name().to_string(), op2.name().to_string())) | ||
| { | ||
| let hashes = (key1.clone(), key2.clone()); | ||
| if let Some(commutation) = commutation_dict.get(&(relative_placement.clone(), hashes)) { | ||
| return Ok(*commutation); | ||
| } | ||
| } | ||
|
|
||
| if second_qargs.len() > matrix_max_num_qubits as usize { | ||
| if qargs2.len() > matrix_max_num_qubits as usize { | ||
| return Ok(false); | ||
| } | ||
|
|
||
| // Perform matrix multiplication to determine commutation | ||
| let is_commuting = self.commute_matmul( | ||
| first_op, | ||
| first_params, | ||
| first_qargs, | ||
| second_op, | ||
| second_params, | ||
| second_qargs, | ||
| tol, | ||
| )?; | ||
| let is_commuting = self.commute_matmul(op1, params1, qargs1, op2, params2, qargs2, tol)?; | ||
|
|
||
| // TODO: implement a LRU cache for this | ||
| if self.current_cache_entries >= self.cache_max_entries { | ||
| self.clear_cache(); | ||
| } | ||
| // Cache results from is_commuting | ||
| self.cache | ||
| .entry((first_op.name().to_string(), second_op.name().to_string())) | ||
| .entry((op1.name().to_string(), op2.name().to_string())) | ||
| .and_modify(|entries| { | ||
| let key = (relative_placement.clone(), (key1.clone(), key2.clone())); | ||
| entries.insert(key, is_commuting); | ||
|
|
@@ -1184,6 +1197,7 @@ pub fn get_standard_commutation_checker() -> CommutationChecker { | |
| cache: HashMap::new(), | ||
| current_cache_entries: 0, | ||
| gates: None, | ||
| scratch_map: HashMap::new(), | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| --- | ||
| features_transpiler: | ||
| - | | ||
| Improved performance of :class:`.CommutationChecker` when evaluating commutations | ||
| between :class:`.PauliProductRotationGate` and :class:`.PauliProductMeasurement` objects. | ||
| This is achieved by representing their generators as single-terms Paulis and performing | ||
| commutativity checks using these representations. | ||
| - | | ||
| Improved performance of the :class:`.CommutativeOptimization` transpiler pass on | ||
| circuits containing :class:`.PauliProductRotationGate` and :class:`.PauliProductMeasurement` | ||
| objects. |
Uh oh!
There was an error while loading. Please reload this page.