Skip to content

Commit 0d63a93

Browse files
mtreinishCryoris
andauthored
Remove the commutation cache from the commutation checker (#15988)
An early performance optimization we made in the commutation analysis pass in #3878 was to enable caching of the commutation relations between gates. Back then the commutation was only checked via the matrix multiplication method where we would compose a pair of gates' unitary matrices forwards and backwards and determine if the product of those forward and backward compositions were an identity. This was a fairly costly operation back then for various reasons and the cache enabled a large speedup by avoiding repeated computation unnecessarily. However, since that time > 6 years ago the code base has evolved substantially, including not relying on matrix multiplication based commutation determination as the default. Now we have a precomputed library of commuting gates and also have special handling for cases where we can know very easily whether gates commute or not. Additionally, all of this code has been ported to rust so the matrix multiplication based approach is not nearly as expensive (although it's not free either). In this new world the cache is actually doing more harm then good because maintaining the cache, adding extra lookups and hashing while constant time is not free. We've reached a point where all the complexity of the cache is no longer worth it, so this commit removes the commutation cache. One caveat is some aspects of this internal cache have leaked into the public API. Specifically the CommutationChecker class is public, and that includes a documented init argument `cache_max_entries` as well as two public methods `clear_cached_commutations` and `num_cached_entries`. I believe the original intent for these methods was either debugging the cache logic was correct (as they were used in tests) or to enable the user to manage the cache size manually if they so wished. The `cache_max_entries` argument was used to manage the total memory size for the cache to avoid using to much memory in certain applications. Since these are part of the public documented api we can't remove them without violating our stability guidelines. So instead this commit opts to just make them no-ops or in the case of the `num_cached_entries` it will always return 0 (since there are no longer any cached entries). These are all marked as deprecated in this PR to mark them for removal in 3.0. In practice there is a minimal performance difference from this change which means we don't need the extra code anymore, although in some very specific benchmarks a small speedup may be seen (those dominated by commutation checking). However, there is one case where running without a cache can be slower, in cases when there are a large number of gate pairs that involves a gate that we don't know whether it commutes or not without the matrix multiplication (i.e. it's not a known pauli, rotation gate, or in the library) and we have multiple repeated pairs of the same gate. In practice this doesn't come up very frequently because typically in a preset passmanager's workflow we have lowered to all 1q and 2q gates and that lowering involves standard gates we know how to work with in the checker. The only edge case is if there was a circuit with a large number of custom 1q or 2q gates that have matrix definitions in a circuit (which is not common). But, the asv benchmark for commutation analysis will likely show a roughly 5x slowdown with this commit. That benchmark is highlighting this edge case because it is running the pass on a random circuit with gates up to 3 qubits in width which will involve multiple repeated checks via matrix multiplication which does not come up in practice normally. Additionally, unlike in #3878 the regression being flagged is only on the 5x slowdown is on the scale of tens of ms, back in 2020 we could only have dreamed to have the CommutationAnalysis pass execute in < 100ms in asv, let alone a world where a 5x regression flagged in asv would be so quick. The entire pass was at least 2 orders of magnitude slower at that point we introduced the cache. Co-authored-by: Julien Gacon <jules.gacon@googlemail.com>
1 parent fa215e4 commit 0d63a93

5 files changed

Lines changed: 38 additions & 243 deletions

File tree

crates/transpiler/src/commutation_checker.rs

Lines changed: 7 additions & 216 deletions
Original file line numberDiff line numberDiff line change
@@ -296,30 +296,26 @@ where
296296
}
297297

298298
/// This is the internal structure for the Python CommutationChecker class
299-
/// It handles the actual commutation checking, cache management, and library
300-
/// lookups. It's not meant to be a public facing Python object though and only used
299+
/// It handles the actual commutation checking, and library lookups. It's
300+
/// not meant to be a public facing Python object though and only used
301301
/// internally by the Python class.
302302
#[pyclass(module = "qiskit._accelerate.commutation_checker")]
303303
pub struct CommutationChecker {
304304
library: CommutationLibrary,
305-
cache_max_entries: usize,
306-
cache: HashMap<(String, String), CommutationCacheEntry>,
307-
current_cache_entries: usize,
308305
#[pyo3(get)]
309306
gates: Option<HashSet<String>>,
310307
}
311308

312309
#[pymethods]
313310
impl CommutationChecker {
314-
#[pyo3(signature = (standard_gate_commutations=None, cache_max_entries=1_000_000, gates=None))]
311+
#[pyo3(signature = (standard_gate_commutations=None, gates=None))]
315312
#[new]
316313
fn py_new(
317314
standard_gate_commutations: Option<Bound<PyAny>>,
318-
cache_max_entries: usize,
319315
gates: Option<HashSet<String>>,
320316
) -> Self {
321317
let library = CommutationLibrary::py_new(standard_gate_commutations);
322-
CommutationChecker::new(Some(library), cache_max_entries, gates)
318+
CommutationChecker::new(Some(library), gates)
323319
}
324320

325321
#[pyo3(signature=(op1, op2, max_num_qubits=None, approximation_degree=1., matrix_max_num_qubits=3))]
@@ -388,52 +384,18 @@ impl CommutationChecker {
388384
)?)
389385
}
390386

391-
/// Return the current number of cache entries
392-
fn num_cached_entries(&self) -> usize {
393-
self.current_cache_entries
394-
}
395-
396-
/// Clear the cache
397-
fn clear_cached_commutations(&mut self) {
398-
self.clear_cache()
399-
}
400-
401387
fn __getstate__(&self, py: Python) -> PyResult<Py<PyDict>> {
402388
let out_dict = PyDict::new(py);
403-
out_dict.set_item("cache_max_entries", self.cache_max_entries)?;
404-
out_dict.set_item("current_cache_entries", self.current_cache_entries)?;
405-
let cache_dict = PyDict::new(py);
406-
for (key, value) in &self.cache {
407-
cache_dict.set_item(key, commutation_entry_to_pydict(py, value)?)?;
408-
}
409-
out_dict.set_item("cache", cache_dict)?;
410389
out_dict.set_item("library", self.library.library.clone().into_pyobject(py)?)?;
411390
out_dict.set_item("gates", self.gates.clone())?;
412391
Ok(out_dict.unbind())
413392
}
414393

415394
fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
416395
let dict_state = state.cast_bound::<PyDict>(py)?;
417-
self.cache_max_entries = dict_state
418-
.get_item("cache_max_entries")?
419-
.unwrap()
420-
.extract()?;
421-
self.current_cache_entries = dict_state
422-
.get_item("current_cache_entries")?
423-
.unwrap()
424-
.extract()?;
425396
self.library = CommutationLibrary {
426397
library: dict_state.get_item("library")?.unwrap().extract()?,
427398
};
428-
let raw_cache: Bound<PyDict> = dict_state.get_item("cache")?.unwrap().extract()?;
429-
self.cache = HashMap::with_capacity(raw_cache.len());
430-
for (key, value) in raw_cache.iter() {
431-
let value_dict: &Bound<PyDict> = value.cast()?;
432-
self.cache.insert(
433-
key.extract()?,
434-
commutation_cache_entry_from_pydict(value_dict)?,
435-
);
436-
}
437399
self.gates = dict_state.get_item("gates")?.unwrap().extract()?;
438400
Ok(())
439401
}
@@ -444,21 +406,13 @@ impl CommutationChecker {
444406
///
445407
/// # Arguments
446408
///
447-
/// - `library`: An optional existing [CommutationLibrary] with cached entries.
448-
/// - `cache_max_entries`: The maximum size of the cache.
409+
/// - `library`: An optional existing [CommutationLibrary].
449410
/// - `gates`: An optional set of gates (by name) to check commutations for. If `None`,
450-
/// commutation is cached and checked for all gates.
451-
pub fn new(
452-
library: Option<CommutationLibrary>,
453-
cache_max_entries: usize,
454-
gates: Option<HashSet<String>>,
455-
) -> Self {
411+
/// commutation is checked for all gates.
412+
pub fn new(library: Option<CommutationLibrary>, gates: Option<HashSet<String>>) -> Self {
456413
// Initialize sets before they are used in the commutation checker
457414
CommutationChecker {
458415
library: library.unwrap_or(CommutationLibrary { library: None }),
459-
cache: HashMap::new(),
460-
cache_max_entries,
461-
current_cache_entries: 0,
462416
gates,
463417
}
464418
}
@@ -586,39 +540,6 @@ impl CommutationChecker {
586540
(qargs1, qargs2)
587541
};
588542

589-
// For our cache to work correctly, we require the gate's definition to only depend on the
590-
// ``params`` attribute. This cannot be guaranteed for custom gates, so we only check
591-
// the cache for
592-
// * gates we know are in the cache (SUPPORTED_OPS), or
593-
// * standard gates with float params (otherwise we cannot cache them)
594-
let is_cachable = |op: &OperationRef, params: &[Param]| {
595-
if let OperationRef::StandardGate(gate) = op {
596-
SUPPORTED_OP[(*gate) as usize]
597-
|| params.iter().all(|p| matches!(p, Param::Float(_)))
598-
} else {
599-
false
600-
}
601-
};
602-
let check_cache =
603-
is_cachable(first_op, first_params) && is_cachable(second_op, second_params);
604-
605-
if !check_cache {
606-
// The arguments are sorted, so if first_qargs.len() > matrix_max_num_qubits, then
607-
// second_qargs.len() > matrix_max_num_qubits as well.
608-
if second_qargs.len() > matrix_max_num_qubits as usize {
609-
return Ok(false);
610-
}
611-
return self.commute_matmul(
612-
first_op,
613-
first_params,
614-
first_qargs,
615-
second_op,
616-
second_params,
617-
second_qargs,
618-
tol,
619-
);
620-
}
621-
622543
// Query commutation library
623544
let relative_placement = get_relative_placement(first_qargs, second_qargs);
624545
if let Some(is_commuting) =
@@ -628,19 +549,6 @@ impl CommutationChecker {
628549
return Ok(is_commuting);
629550
}
630551

631-
// Query cache
632-
let key1 = hashable_params(first_params)?;
633-
let key2 = hashable_params(second_params)?;
634-
if let Some(commutation_dict) = self
635-
.cache
636-
.get(&(first_op.name().to_string(), second_op.name().to_string()))
637-
{
638-
let hashes = (key1.clone(), key2.clone());
639-
if let Some(commutation) = commutation_dict.get(&(relative_placement.clone(), hashes)) {
640-
return Ok(*commutation);
641-
}
642-
}
643-
644552
if second_qargs.len() > matrix_max_num_qubits as usize {
645553
return Ok(false);
646554
}
@@ -656,25 +564,6 @@ impl CommutationChecker {
656564
tol,
657565
)?;
658566

659-
// TODO: implement a LRU cache for this
660-
if self.current_cache_entries >= self.cache_max_entries {
661-
self.clear_cache();
662-
}
663-
// Cache results from is_commuting
664-
self.cache
665-
.entry((first_op.name().to_string(), second_op.name().to_string()))
666-
.and_modify(|entries| {
667-
let key = (relative_placement.clone(), (key1.clone(), key2.clone()));
668-
entries.insert(key, is_commuting);
669-
self.current_cache_entries += 1;
670-
})
671-
.or_insert_with(|| {
672-
let mut entries = HashMap::with_capacity(1);
673-
let key = (relative_placement, (key1, key2));
674-
entries.insert(key, is_commuting);
675-
self.current_cache_entries += 1;
676-
entries
677-
});
678567
Ok(is_commuting)
679568
}
680569

@@ -773,11 +662,6 @@ impl CommutationChecker {
773662
let matrix_tol = tol;
774663
Ok(phase.abs() <= tol && (1.0 - fid).abs() <= matrix_tol)
775664
}
776-
777-
fn clear_cache(&mut self) {
778-
self.cache.clear();
779-
self.current_cache_entries = 0;
780-
}
781665
}
782666

783667
/// A pre-check status.
@@ -1085,104 +969,11 @@ impl<'a, 'py> FromPyObject<'a, 'py> for CommutationLibraryEntry {
1085969
}
1086970
}
1087971

1088-
type CacheKey = (
1089-
SmallVec<[Option<Qubit>; 2]>,
1090-
(SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>),
1091-
);
1092-
1093-
type CommutationCacheEntry = HashMap<CacheKey, bool>;
1094-
1095-
fn commutation_entry_to_pydict(py: Python, entry: &CommutationCacheEntry) -> PyResult<Py<PyDict>> {
1096-
let out_dict = PyDict::new(py);
1097-
for (k, v) in entry.iter() {
1098-
let qubits = PyTuple::new(py, k.0.iter().map(|q| q.map(|t| t.0)))?;
1099-
let params0 = PyTuple::new(py, k.1.0.iter().map(|pk| pk.0))?;
1100-
let params1 = PyTuple::new(py, k.1.1.iter().map(|pk| pk.0))?;
1101-
out_dict.set_item(
1102-
PyTuple::new(py, [qubits, PyTuple::new(py, [params0, params1])?])?,
1103-
PyBool::new(py, *v),
1104-
)?;
1105-
}
1106-
Ok(out_dict.unbind())
1107-
}
1108-
1109-
fn commutation_cache_entry_from_pydict(dict: &Bound<PyDict>) -> PyResult<CommutationCacheEntry> {
1110-
let mut ret = hashbrown::HashMap::with_capacity(dict.len());
1111-
for (k, v) in dict {
1112-
let raw_key: CacheKeyRaw = k.extract()?;
1113-
let qubits = raw_key.0.iter().map(|q| q.map(Qubit)).collect();
1114-
let params0: SmallVec<_> = raw_key.1.0;
1115-
let params1: SmallVec<_> = raw_key.1.1;
1116-
let v: bool = v.extract()?;
1117-
ret.insert((qubits, (params0, params1)), v);
1118-
}
1119-
Ok(ret)
1120-
}
1121-
1122-
type CacheKeyRaw = (
1123-
SmallVec<[Option<u32>; 2]>,
1124-
(SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>),
1125-
);
1126-
1127-
/// This newtype wraps a f64 to make it hashable so we can cache parameterized gates
1128-
/// based on the parameter value (assuming it's a float angle). However, Rust doesn't do
1129-
/// this by default and there are edge cases to track around it's usage. The biggest one
1130-
/// is this does not work with f64::NAN, f64::INFINITY, or f64::NEG_INFINITY
1131-
/// If you try to use these values with this type they will not work as expected.
1132-
/// This should only be used with the cache hashmap's keys and not used beyond that.
1133-
#[derive(Debug, Copy, Clone, PartialEq, FromPyObject)]
1134-
struct ParameterKey(f64);
1135-
1136-
impl ParameterKey {
1137-
fn key(&self) -> u64 {
1138-
// If we get a -0 the to_bits() return is not equivalent to 0
1139-
// because -0 has the sign bit set we'd be hashing 9223372036854775808
1140-
// and be storing it separately from 0. So this normalizes all 0s to
1141-
// be represented by 0
1142-
if self.0 == 0. { 0 } else { self.0.to_bits() }
1143-
}
1144-
}
1145-
1146-
impl std::hash::Hash for ParameterKey {
1147-
fn hash<H>(&self, state: &mut H)
1148-
where
1149-
H: std::hash::Hasher,
1150-
{
1151-
self.key().hash(state)
1152-
}
1153-
}
1154-
1155-
impl Eq for ParameterKey {}
1156-
1157-
fn hashable_params(params: &[Param]) -> Result<SmallVec<[ParameterKey; 3]>, CommutationError> {
1158-
params
1159-
.iter()
1160-
.map(|x| {
1161-
if let Param::Float(x) = x {
1162-
// NaN and Infinity (negative or positive) are not valid
1163-
// parameter values and our hacks to store parameters in
1164-
// the cache HashMap don't take these into account. So return
1165-
// an error to Python if we encounter these values.
1166-
if x.is_nan() || x.is_infinite() {
1167-
Err(CommutationError::HashingNaN)
1168-
} else {
1169-
Ok(ParameterKey(*x))
1170-
}
1171-
} else {
1172-
Err(CommutationError::HashingParameter)
1173-
}
1174-
})
1175-
.collect()
1176-
}
1177-
1178972
#[pyfunction]
1179973
pub fn get_standard_commutation_checker() -> CommutationChecker {
1180974
let library = standard_gates_commutations::get_commutation_library();
1181975
CommutationChecker {
1182976
library,
1183-
cache_max_entries: 1_000_000,
1184-
cache: HashMap::new(),
1185-
current_cache_entries: 0,
1186977
gates: None,
1187978
}
1188979
}

qiskit/circuit/commutation_checker.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from qiskit.circuit.operation import Operation
2020
from qiskit.circuit import Qubit
2121
from qiskit._accelerate.commutation_checker import CommutationChecker as RustChecker
22+
from qiskit.utils import deprecate_arg, deprecate_func
2223

2324

2425
class CommutationChecker:
@@ -45,18 +46,18 @@ class CommutationChecker:
4546
gates with free parameters (such as :class:`.RXGate` with a :class:`.ParameterExpression` as
4647
angle). Otherwise, a matrix-based check is performed, where two operations are said to
4748
commute, if the average gate fidelity of performing the commutation is above a certain threshold
48-
(see ``approximation_degree``). The result of this commutation is then added to the
49-
cached lookup table.
49+
(see ``approximation_degree``).
5050
"""
5151

52+
@deprecate_arg("cache_max_entries", since="2.5.0", removal_timeline="in Qiskit 3.0")
5253
def __init__(
5354
self,
5455
standard_gate_commutations: dict | None = None,
5556
cache_max_entries: int = 10**6,
5657
*,
5758
gates: set[str] | None = None,
5859
):
59-
self.cc = RustChecker(standard_gate_commutations, cache_max_entries, gates)
60+
self.cc = RustChecker(standard_gate_commutations, gates)
6061

6162
def commute_nodes(
6263
self,
@@ -118,13 +119,21 @@ def commute(
118119
matrix_max_num_qubits,
119120
)
120121

122+
@deprecate_func(since="2.5", removal_timeline="in Qiskit 3.0")
121123
def num_cached_entries(self):
122-
"""Returns number of cached entries"""
123-
return self.cc.num_cached_entries()
124+
"""Returns number of cached entries
124125
126+
This method will always return 0 because there is no longer an
127+
internal cache.
128+
"""
129+
return 0
130+
131+
@deprecate_func(since="2.5", removal_timeline="in Qiskit 3.0")
125132
def clear_cached_commutations(self):
126-
"""Clears the dictionary holding cached commutations"""
127-
self.cc.clear_cached_commutations()
133+
"""Clears the dictionary holding cached commutations
134+
135+
This method is a no-op as there is no longer an internal cache
136+
"""
128137

129138
def check_commutation_entries(
130139
self,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
deprecations_circuits:
3+
- The ``cache_max_entries`` argument on the the :class:`.CommutationChecker` class's
4+
constructor is deprecated and will be removed in Qiskit 3.0.0. This argument no longer has
5+
any effect because the :class:`.CommutationChecker` no longer maintains an internal cache
6+
of commutation relationships between gates as it is no longer necessary.
7+
- The :meth:`.CommutationChecker.clear_cached_commutations` method is deprecated and will be
8+
removed in Qiskit 3.0.0. This method no longer has any effect because the
9+
internal cache was removed from the :class:`.CommutationChecker` class as
10+
it was no longer necessary so there is nothing to clear.
11+
- The :meth:`.CommutationChecker.num_cached_entries` method is deprecated
12+
and will be removed in Qiskit 3.0.0. Since the removal of the internal
13+
cache from the :class:`.CommutationChecker` this method always returns 0
14+
because there are no internally cached entries in a :class:`.CommutationChecker`
15+
instance.

0 commit comments

Comments
 (0)