Skip to content

Commit f8bd5d0

Browse files
Reimplementing QSD using faer instead of nalgebra. (#15874)
* Reimplementing QSD using faer instead of nalgebra. * Addressing review comments * refacting SVDResult into a struct * for naming consistency, changing SVDResult to return V rather than V.adjoint * refactoring ZXZResult into a struct * improved diagonal matrix construction * changing qs_decomposition API to work with ArrayView2 * using ndarray_to_faer * remove aux function zxz_decomp_svd * updating minimum tol to 1e-12 * cloned -> copied * Replacing QSDError in mod.rs by LinAlgError * remove pub from ZXZResult * improving user-facing error messages
1 parent 2fb5be6 commit f8bd5d0

8 files changed

Lines changed: 294 additions & 149 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ hashbrown.version = "0.15.5"
2121
num-bigint = "0.4"
2222
num-complex = "0.4"
2323
nalgebra = "0.33"
24+
faer = "0.24"
2425
numpy = "0.28"
2526
ndarray = "0.16"
2627
smallvec = "1.15"

crates/synthesis/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ num-traits.workspace = true
2727
rand.workspace = true
2828
rand_pcg.workspace = true
2929
rand_distr.workspace = true
30-
faer = "0.24"
30+
faer.workspace = true
3131
rustworkx-core.workspace = true
3232
rustiq-core = "0.0.11"
3333
rsgridsynth = "0.2.0"

crates/synthesis/src/linalg/mod.rs

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,51 @@
1111
// that they have been altered from the originals.
1212

1313
use approx::{abs_diff_eq, relative_ne};
14+
use faer::Mat;
1415
use faer::MatRef;
1516
use nalgebra::{DMatrix, DMatrixView, Dim, Dyn, MatrixView, Scalar, ViewStorage};
1617
use ndarray::ArrayView2;
1718
use ndarray::ShapeBuilder;
1819
use num_complex::Complex64;
20+
use pyo3::PyErr;
21+
use thiserror::Error;
22+
23+
use crate::QiskitError;
1924

2025
pub mod cos_sin_decomp;
2126

2227
const ATOL_DEFAULT: f64 = 1e-8;
2328
const RTOL_DEFAULT: f64 = 1e-5;
2429

30+
/// Tolerance used for debug checking
31+
pub const VERIFY_TOL: f64 = 1e-7;
32+
33+
/// Errors that might occur in linear algebra computations
34+
#[derive(Error, Debug)]
35+
pub enum LinAlgError {
36+
#[error("Eigen decomposition failed")]
37+
EigenDecompositionFailed,
38+
39+
#[error("SVD decomposition failed")]
40+
SVDDecompositionFailed,
41+
}
42+
43+
impl From<LinAlgError> for PyErr {
44+
fn from(error: LinAlgError) -> Self {
45+
match error {
46+
LinAlgError::EigenDecompositionFailed => QiskitError::new_err(
47+
"Internal eigendecomposition failed. \
48+
This can point to a numerical tolerance issue.",
49+
),
50+
51+
LinAlgError::SVDDecompositionFailed => QiskitError::new_err(
52+
"Internal SVD decomposition failed. \
53+
This can point to a numerical tolerance issue.",
54+
),
55+
}
56+
}
57+
}
58+
2559
#[inline]
2660
pub fn nalgebra_array_view<T: Scalar, R: Dim, C: Dim>(mat: MatrixView<T, R, C>) -> ArrayView2<T> {
2761
let dim = ndarray::Dim(mat.shape());
@@ -151,7 +185,7 @@ fn verify_svd_decomp(
151185
w: DMatrixView<Complex64>,
152186
) -> bool {
153187
let mat_check = v * s * w;
154-
abs_diff_eq!(mat, mat_check.as_view(), epsilon = 1e-7)
188+
abs_diff_eq!(mat, mat_check.as_view(), epsilon = VERIFY_TOL)
155189
}
156190

157191
/// Verifies the given matrix U is unitary by comparing U*U to the identity matrix
@@ -161,7 +195,7 @@ pub fn verify_unitary(u: &DMatrix<Complex64>) -> bool {
161195
let id_mat = DMatrix::identity(n, n);
162196
let uu = u.adjoint() * u;
163197

164-
abs_diff_eq!(uu, id_mat, epsilon = 1e-7)
198+
abs_diff_eq!(uu, id_mat, epsilon = VERIFY_TOL)
165199
}
166200

167201
/// Given a matrix that is "close" to unitary, returns the closest
@@ -205,6 +239,7 @@ pub fn svd_decomposition(
205239
) -> (DMatrix<Complex64>, DMatrix<Complex64>, DMatrix<Complex64>) {
206240
let mat_view: DMatrixView<Complex64> = mat.as_view();
207241
let faer_mat: MatRef<Complex64> = nalgebra_to_faer(mat_view);
242+
208243
let faer_svd = faer_mat.svd().unwrap();
209244

210245
let u_faer = faer_svd.U();
@@ -232,6 +267,108 @@ pub fn svd_decomposition(
232267
(u_na.into(), s_na, v_na)
233268
}
234269

270+
/// Result for the singular value decomposition (SVD) of a matrix `A`.
271+
///
272+
/// The decomposition is given by three matrices `(U, S, V)` such that `A = U * S * V^\dagger`.
273+
pub struct SVDResult {
274+
pub u: Mat<Complex64>,
275+
pub s: Mat<Complex64>,
276+
pub v: Mat<Complex64>,
277+
}
278+
279+
/// Runs singular valued decomposition on `mat`.
280+
pub fn svd_decomposition_faer(mat: MatRef<Complex64>) -> Result<SVDResult, LinAlgError> {
281+
let svd = mat.svd().map_err(|_| LinAlgError::SVDDecompositionFailed)?;
282+
283+
let n = mat.nrows();
284+
let u = svd.U().to_owned();
285+
let v = svd.V().to_owned();
286+
let mut s = Mat::zeros(n, n);
287+
s.diagonal_mut().copy_from(svd.S());
288+
289+
let svd_result = SVDResult { u, s, v };
290+
debug_assert!(verify_svd_decomposition_faer(mat.as_ref(), &svd_result));
291+
292+
Ok(svd_result)
293+
}
294+
295+
/// Computes the eigenvalues and the eigenvectors of a square matrix
296+
pub fn eigendecomposition_faer(
297+
mat: MatRef<Complex64>,
298+
) -> Result<(Vec<Complex64>, Mat<Complex64>), LinAlgError> {
299+
let eigh = mat
300+
.eigen()
301+
.map_err(|_| LinAlgError::EigenDecompositionFailed)?;
302+
303+
let vmat = eigh.U().to_owned();
304+
// unfortunately, we need to call closest_unitary_faer here
305+
let vmat = closest_unitary_faer(vmat.as_ref())?;
306+
let eigvals: Vec<Complex64> = eigh.S().column_vector().iter().copied().collect();
307+
Ok((eigvals, vmat))
308+
}
309+
310+
pub fn closest_unitary_faer(mat: MatRef<Complex64>) -> Result<Mat<Complex64>, LinAlgError> {
311+
let svd = mat.svd().map_err(|_| LinAlgError::SVDDecompositionFailed)?;
312+
Ok(svd.U() * svd.V().adjoint())
313+
}
314+
315+
pub fn from_diagonal_faer(diag: &[Complex64]) -> Mat<Complex64> {
316+
let n = diag.len();
317+
let mut mat = Mat::zeros(n, n);
318+
mat.diagonal_mut()
319+
.column_vector_mut()
320+
.iter_mut()
321+
.zip(diag)
322+
.for_each(|(x, y)| *x = *y);
323+
mat
324+
}
325+
326+
/// Returns a block matrix `[a, b; c, d]`.
327+
/// The matrices `a`, `b`, `c`, `d` are all assumed to be square matrices of the same size
328+
pub fn block_matrix_faer(
329+
a: MatRef<Complex64>,
330+
b: MatRef<Complex64>,
331+
c: MatRef<Complex64>,
332+
d: MatRef<Complex64>,
333+
) -> Mat<Complex64> {
334+
let n = a.nrows();
335+
let mut block_matrix = Mat::<Complex64>::zeros(2 * n, 2 * n);
336+
block_matrix.as_mut().submatrix_mut(0, 0, n, n).copy_from(a);
337+
block_matrix.as_mut().submatrix_mut(0, n, n, n).copy_from(b);
338+
block_matrix.as_mut().submatrix_mut(n, 0, n, n).copy_from(c);
339+
block_matrix.as_mut().submatrix_mut(n, n, n, n).copy_from(d);
340+
block_matrix
341+
}
342+
343+
/// Verify SVD decomposition gives the same unitary
344+
fn verify_svd_decomposition_faer(mat: MatRef<Complex64>, svd: &SVDResult) -> bool {
345+
let mat_check = svd.u.as_ref() * svd.s.as_ref() * svd.v.as_ref().adjoint();
346+
(mat - mat_check).norm_max() < VERIFY_TOL
347+
}
348+
349+
/// Verifies the given matrix U is unitary by comparing U*U to the identity matrix
350+
pub fn verify_unitary_faer(u: MatRef<Complex64>) -> bool {
351+
let n = u.shape().0;
352+
353+
let id_mat = Mat::<Complex64>::identity(n, n);
354+
let uu = u.adjoint() * u;
355+
356+
(uu.as_ref() - id_mat.as_ref()).norm_max() < VERIFY_TOL
357+
}
358+
359+
// check whether a matrix is zero (up to tolerance)
360+
pub fn is_zero_matrix_faer(mat: MatRef<Complex64>, atol: Option<f64>) -> bool {
361+
let atol = atol.unwrap_or(1e-12);
362+
for i in 0..mat.nrows() {
363+
for j in 0..mat.ncols() {
364+
if !abs_diff_eq!(mat[(i, j)], Complex64::ZERO, epsilon = atol) {
365+
return false;
366+
}
367+
}
368+
}
369+
true
370+
}
371+
235372
#[cfg(test)]
236373
mod test {
237374
use super::*;

0 commit comments

Comments
 (0)