Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ hashbrown.version = "0.15.5"
num-bigint = "0.4"
num-complex = "0.4"
nalgebra = "0.33"
faer = "0.24"
numpy = "0.28"
ndarray = "0.16"
smallvec = "1.15"
Expand Down
2 changes: 1 addition & 1 deletion crates/synthesis/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ num-traits.workspace = true
rand.workspace = true
rand_pcg.workspace = true
rand_distr.workspace = true
faer = "0.24"
faer.workspace = true
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really need this change anymore but it doesn't hurt either, since we've agreed at this point we'll rely on faer for all the heavy linear algebra. So there is a fairly good chance we'll use it in another crate at some point.

rustworkx-core.workspace = true
rustiq-core = "0.0.11"
rsgridsynth = "0.2.0"
Expand Down
113 changes: 111 additions & 2 deletions crates/synthesis/src/linalg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,23 @@
// that they have been altered from the originals.

use approx::{abs_diff_eq, relative_ne};
use faer::Mat;
use faer::MatRef;
use nalgebra::{DMatrix, DMatrixView, Dim, Dyn, MatrixView, Scalar, ViewStorage};
use ndarray::ArrayView2;
use ndarray::ShapeBuilder;
use num_complex::Complex64;

use crate::qsd::QSDError;

pub mod cos_sin_decomp;

const ATOL_DEFAULT: f64 = 1e-8;
const RTOL_DEFAULT: f64 = 1e-5;

/// Tolerance used for debug checking
pub const VERIFY_TOL: f64 = 1e-7;

#[inline]
pub fn nalgebra_array_view<T: Scalar, R: Dim, C: Dim>(mat: MatrixView<T, R, C>) -> ArrayView2<T> {
let dim = ndarray::Dim(mat.shape());
Expand Down Expand Up @@ -151,7 +157,7 @@ fn verify_svd_decomp(
w: DMatrixView<Complex64>,
) -> bool {
let mat_check = v * s * w;
abs_diff_eq!(mat, mat_check.as_view(), epsilon = 1e-7)
abs_diff_eq!(mat, mat_check.as_view(), epsilon = VERIFY_TOL)
}

/// Verifies the given matrix U is unitary by comparing U*U to the identity matrix
Expand All @@ -161,7 +167,7 @@ pub fn verify_unitary(u: &DMatrix<Complex64>) -> bool {
let id_mat = DMatrix::identity(n, n);
let uu = u.adjoint() * u;

abs_diff_eq!(uu, id_mat, epsilon = 1e-7)
abs_diff_eq!(uu, id_mat, epsilon = VERIFY_TOL)
}

/// Given a matrix that is "close" to unitary, returns the closest
Expand Down Expand Up @@ -205,6 +211,7 @@ pub fn svd_decomposition(
) -> (DMatrix<Complex64>, DMatrix<Complex64>, DMatrix<Complex64>) {
let mat_view: DMatrixView<Complex64> = mat.as_view();
let faer_mat: MatRef<Complex64> = nalgebra_to_faer(mat_view);

let faer_svd = faer_mat.svd().unwrap();

let u_faer = faer_svd.U();
Expand Down Expand Up @@ -232,6 +239,108 @@ pub fn svd_decomposition(
(u_na.into(), s_na, v_na)
}

/// Result for the singular value decomposition (SVD) of a matrix `A`.
///
/// The decomposition is given by three matrices `(U, S, V)` such that `A = U * S * V^\dagger`.
pub struct SVDResult {
pub u: Mat<Complex64>,
pub s: Mat<Complex64>,
pub v: Mat<Complex64>,
}

/// Runs singular valued decomposition on `mat`.
pub fn svd_decomposition_faer(mat: MatRef<Complex64>) -> Result<SVDResult, QSDError> {
Comment thread
alexanderivrii marked this conversation as resolved.
Outdated
let svd = mat.svd().map_err(|_| QSDError::SVDDecompositionFailed)?;

let n = mat.nrows();
let u = svd.U().to_owned();
let v = svd.V().to_owned();
let mut s = Mat::zeros(n, n);
s.diagonal_mut().copy_from(svd.S());

let svd_result = SVDResult { u, s, v };
debug_assert!(verify_svd_decomposition_faer(mat.as_ref(), &svd_result));

Ok(svd_result)
}

/// Computes the eigenvalues and the eigenvectors of a square matrix
pub fn eigendecomposition_faer(
mat: MatRef<Complex64>,
) -> Result<(Vec<Complex64>, Mat<Complex64>), QSDError> {
Comment thread
alexanderivrii marked this conversation as resolved.
Outdated
let eigh = mat
.eigen()
.map_err(|_| QSDError::EigenDecompositionFailed)?;

let vmat = eigh.U().to_owned();
// unfortunately, we need to call closest_unitary_faer here
let vmat = closest_unitary_faer(vmat.as_ref())?;
let eigvals: Vec<Complex64> = eigh.S().column_vector().iter().cloned().collect();
Comment thread
alexanderivrii marked this conversation as resolved.
Outdated
Ok((eigvals, vmat))
}

pub fn closest_unitary_faer(mat: MatRef<Complex64>) -> Result<Mat<Complex64>, QSDError> {
Comment thread
alexanderivrii marked this conversation as resolved.
Outdated
let svd = mat.svd().map_err(|_| QSDError::SVDDecompositionFailed)?;
Ok(svd.U() * svd.V().adjoint())
}

pub fn from_diagonal_faer(diag: &[Complex64]) -> Mat<Complex64> {
let n = diag.len();
let mut mat = Mat::zeros(n, n);
mat.diagonal_mut()
.column_vector_mut()
.iter_mut()
.zip(diag)
.for_each(|(x, y)| *x = *y);
mat
}

/// Returns a block matrix `[a, b; c, d]`.
/// The matrices `a`, `b`, `c`, `d` are all assumed to be square matrices of the same size
pub fn block_matrix_faer(
a: MatRef<Complex64>,
b: MatRef<Complex64>,
c: MatRef<Complex64>,
d: MatRef<Complex64>,
) -> Mat<Complex64> {
let n = a.nrows();
let mut block_matrix = Mat::<Complex64>::zeros(2 * n, 2 * n);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to initialize this with zero? It feels unnecessary because you should be populating every element in the matrix based on the blocks.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only way I can see how to avoid this (keeping the current code structure) is doing something like

 let mut block_matrix = Mat::<Complex64>::with_capacity(2 * n, 2 * n);
 unsafe { block_matrix.set_dims(2 * n, 2 * n) };

However, I semi-expect that the compiler should be able to optimize this itself. I have kept the original code for now, but replaced the loops as per Shelly's suggestion. I am also not very worried about this specific code as for now it's only used in debug assertions.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking something along those lines. I wasn't sure if faer's Mat type had a method for creating an uninitialized matrix directly or you had to do something like that. But it would be inside an unsafe either way. The other option would be to use from_fn and have conditions on the index ranges.

I find it unlikely that the compiler will able to optimize it away because Mat::zeros() is calling Mat::from_fn which is explicitly calling the closure that returns zero for every element in the array. But it's only a small performance difference, it's something we can explore later if this becomes a bottleneck.

block_matrix.as_mut().submatrix_mut(0, 0, n, n).copy_from(a);
block_matrix.as_mut().submatrix_mut(0, n, n, n).copy_from(b);
block_matrix.as_mut().submatrix_mut(n, 0, n, n).copy_from(c);
block_matrix.as_mut().submatrix_mut(n, n, n, n).copy_from(d);
block_matrix
}

/// Verify SVD decomposition gives the same unitary
fn verify_svd_decomposition_faer(mat: MatRef<Complex64>, svd: &SVDResult) -> bool {
let mat_check = svd.u.as_ref() * svd.s.as_ref() * svd.v.as_ref().adjoint();
(mat - mat_check).norm_max() < VERIFY_TOL
}

/// Verifies the given matrix U is unitary by comparing U*U to the identity matrix
pub fn verify_unitary_faer(u: MatRef<Complex64>) -> bool {
let n = u.shape().0;

let id_mat = Mat::<Complex64>::identity(n, n);
let uu = u.adjoint() * u;

(uu.as_ref() - id_mat.as_ref()).norm_max() < VERIFY_TOL
}

// check whether a matrix is zero (up to tolerance)
pub fn is_zero_matrix_faer(mat: MatRef<Complex64>, atol: Option<f64>) -> bool {
let atol = atol.unwrap_or(1e-12);
for i in 0..mat.nrows() {
for j in 0..mat.ncols() {
if !abs_diff_eq!(mat[(i, j)], Complex64::ZERO, epsilon = atol) {
return false;
}
}
}
true
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
Loading