
Collateral Optimization with JAX

GPU-accelerated LP solver for large-scale financial collateral optimization, using the Primal-Dual Hybrid Gradient (PDHG / Chambolle-Pock) algorithm implemented in JAX with SciPy sparse matrix construction.


Problem

Allocating collateral assets across counterparties to satisfy margin requirements while minimizing funding costs:

$$ \begin{align} \text{minimize} \quad & \sum_{i,j,k} \text{cost}[i,j,k] \cdot x[i,j,k] \\ \text{subject to} \quad & \sum_{i,k} \text{haircut}[j,i] \cdot x[i,j,k] \geq \text{margin}[j] \quad \text{(margin coverage)} \\ & \sum_j x[i,j,k] \leq \text{avail}[i,k] \quad \text{(asset availability)} \\ & \sum_{\substack{i \in \text{type}_t \\ k}} \text{haircut}[j,i] \cdot x[i,j,k] \leq \alpha[j,t] \cdot \text{margin}[j] \quad \text{(concentration)} \\ & x[i,j,k] \geq 0 \end{align} $$

Variables: x[i,j,k] is the amount of asset i allocated to counterparty j from pool k. Scale: realistic institutional portfolios yield 100K to 10M+ variables.


Why first-order methods

Interior-point solvers (Gurobi, HiGHS) scale as O(n^3) per iteration and require dense factorizations that do not fit in GPU memory for large problems. PDHG requires only sparse matrix-vector products — O(nnz) per iteration — and maps directly onto GPU parallelism.


Implementation highlights

Vectorized matrix construction (src/formulation/problem_builder.py)
All constraint blocks are built with np.meshgrid + ravel arithmetic. No Python loops: matrix construction for 1M-variable problems takes seconds, not hours.
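A minimal sketch of the meshgrid + ravel idea for the margin-coverage block. The dimensions, haircut values, and the C-order column convention below are illustrative assumptions; the real builder in src/formulation/problem_builder.py may differ in details.

```python
import numpy as np
from scipy.sparse import coo_matrix

# Tiny illustrative instance (shapes are assumptions, not the repo's defaults).
n_assets, n_cps, n_pools = 3, 2, 2
haircut = np.array([[0.98, 0.95, 0.90],
                    [0.97, 0.93, 0.88]])          # haircut[j, i]

# Column index of x[i, j, k] under C-order raveling of an (i, j, k) cube:
#   col = i * n_cps * n_pools + j * n_pools + k
i_idx, k_idx = np.meshgrid(np.arange(n_assets), np.arange(n_pools), indexing="ij")
i_idx, k_idx = i_idx.ravel(), k_idx.ravel()       # all (i, k) pairs, no Python loops
j_idx = np.arange(n_cps)

rows = np.repeat(j_idx, i_idx.size)               # row j for each (i, k) pair
cols = (i_idx[None, :] * n_cps * n_pools
        + j_idx[:, None] * n_pools
        + k_idx[None, :]).ravel()
vals = haircut[rows, np.tile(i_idx, n_cps)]       # entry value haircut[j, i]

A_margin = coo_matrix((vals, (rows, cols)),
                      shape=(n_cps, n_assets * n_cps * n_pools)).tocsr()
```

The same broadcast-then-ravel pattern builds the availability and concentration blocks, so the full matrix assembly stays loop-free.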

Ruiz equilibration (src/solvers/preprocess.py)
Iteratively scales rows and columns of A so that each infinity norm converges to 1. Essential for financial data where haircuts (~0.98), market values (~1e7), and margin requirements (~1e8) span many orders of magnitude.
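The scaling loop can be sketched as follows. This is the textbook Ruiz iteration under assumed conventions (square-root damping, fixed iteration count); the repo's preprocess.py may differ in stopping criteria and details.

```python
import numpy as np
import scipy.sparse as sp

def ruiz_equilibrate(A, iters=10):
    """Iteratively scale rows and columns of sparse A toward unit infinity norm.
    Returns the scaled matrix plus the cumulative diagonal scalings, so that
    A_scaled = diag(d_row) @ A @ diag(d_col)."""
    A = A.tocsr(copy=True)
    m, n = A.shape
    d_row, d_col = np.ones(m), np.ones(n)
    for _ in range(iters):
        r = np.sqrt(abs(A).max(axis=1).toarray().ravel())  # sqrt of row inf-norms
        c = np.sqrt(abs(A).max(axis=0).toarray().ravel())  # sqrt of col inf-norms
        r[r == 0] = 1.0                                    # leave empty rows alone
        c[c == 0] = 1.0
        A = sp.diags(1 / r) @ A @ sp.diags(1 / c)
        d_row /= r
        d_col /= c
    return A, d_row, d_col
```

The returned d_row and d_col are needed to rescale b, c, and the recovered solution back to the original units after solving.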

Relative convergence (src/solvers/pdhg.py)
Tolerances are scale-invariant:

primal: ||max(b - Ax, 0)|| / (1 + ||b||) < tol
dual  : ||max(A'y - c, 0)|| / (1 + ||c||) < tol
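These two formulas translate directly to code. A sketch for the LP min c'x s.t. Ax >= b, x >= 0 (the function name is illustrative, not the repo's exact API):

```python
import numpy as np

def relative_residuals(A, b, c, x, y):
    """Scale-invariant primal/dual infeasibility residuals for
    min c'x  s.t.  Ax >= b, x >= 0  (dual feasibility: A'y <= c, y >= 0)."""
    primal = np.linalg.norm(np.maximum(b - A @ x, 0)) / (1 + np.linalg.norm(b))
    dual = np.linalg.norm(np.maximum(A.T @ y - c, 0)) / (1 + np.linalg.norm(c))
    return primal, dual
```

Dividing by 1 + ||b|| and 1 + ||c|| makes the same tolerance meaningful whether margins are quoted in dollars or in billions.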

Gap-based restart
Every 300 iterations the normalised duality gap is compared to the best seen. If improvement stalls (gap > 0.99 * best_gap), momentum is cleared and the solver warm-starts from the best iterate found so far.
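The restart logic can be sketched as below. Variable names and the dict-based state are illustrative assumptions, not the solver's exact internals:

```python
import numpy as np

def maybe_restart(it, x, y, x_bar, best, c, b, restart_interval=300):
    """Gap-based restart: every `restart_interval` iterations, compare the
    normalised duality gap to the best seen; on stall, warm-start from the
    best iterate and clear extrapolation momentum."""
    if it % restart_interval != 0:
        return x, y, x_bar, best
    pobj, dobj = c @ x, b @ y
    gap = abs(pobj - dobj) / (1 + abs(pobj) + abs(dobj))  # normalised duality gap
    if gap < best["gap"]:
        best = {"gap": gap, "x": x.copy(), "y": y.copy()}
    elif gap > 0.99 * best["gap"]:                        # improvement has stalled
        x, y = best["x"].copy(), best["y"].copy()         # warm-start from best
        x_bar = x.copy()                                  # clear momentum: x_bar = x
    return x, y, x_bar, best
```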

Infeasibility certificate
When the dual norm exceeds a threshold, the solver checks whether the normalised dual direction y_n satisfies A'y_n <= 0 and b'y_n > 0. If so, it returns this ray as a proof of primal infeasibility.
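The check is a direct application of Farkas' lemma: for min c'x s.t. Ax >= b, x >= 0, any ray y >= 0 with A'y <= 0 and b'y > 0 proves no feasible x exists. A sketch (thresholds and function name are assumptions; the solver's exact test may differ):

```python
import numpy as np

def primal_infeasibility_certificate(A, b, y, tol=1e-9):
    """Return the normalised dual ray if it certifies primal infeasibility,
    else None. Assumes y >= 0, which PDHG maintains by projection.
    Farkas: y >= 0, A'y <= 0, b'y > 0  =>  {x >= 0 : Ax >= b} is empty."""
    y_n = y / np.linalg.norm(y)
    if np.all(A.T @ y_n <= tol) and b @ y_n > tol:
        return y_n          # certificate of primal infeasibility
    return None
```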

Operator norm upper bound
Step sizes are set via sigma = tau = 0.99 / L where L = sqrt(max_row_L1 * max_col_L1). This is an O(nnz) bound that avoids the cost of power iteration.
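The bound follows from ||A||_2^2 <= ||A||_1 * ||A||_inf, i.e. the product of the maximum column and row L1-norms, both computable in one pass over the nonzeros. A sketch (function name is illustrative):

```python
import numpy as np

def step_sizes(A):
    """Primal/dual step sizes from an O(nnz) upper bound on ||A||_2.
    Works for dense NumPy arrays and SciPy sparse matrices alike."""
    absA = abs(A)
    max_row_l1 = absA.sum(axis=1).max()   # ||A||_inf
    max_col_l1 = absA.sum(axis=0).max()   # ||A||_1
    L = np.sqrt(max_row_l1 * max_col_l1)  # L >= ||A||_2
    tau = sigma = 0.99 / L                # safe since tau * sigma * ||A||_2^2 < 1
    return tau, sigma
```

Because L overestimates the true operator norm, the steps are conservative but always stable, with none of the per-solve cost of power iteration.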


Sample Result

For a generated problem with:

n_assets=5_000, n_counterparties=500, n_pools=10

>> LP:
  25,000,500 variables | 52,500 constraints | 75,000,500 non-zeros (0.0057% fill)

The solver output:

Solving (GPU backend — set use_jax=False for CPU)...

=======================================================
  Status     : optimal
  Iterations : 900
  Objective  : 77,183,271.71
  Primal res : 4.47e-07
  Dual res   : 7.12e-07
  Solve time : 642.40s

Convergence plot:

(figure: convergence plot, JAX backend)


Project structure

src/
  data/generator.py          synthetic data (assets, counterparties, pools)
  formulation/problem_builder.py  vectorized LP matrix construction
  solvers/preprocess.py      Ruiz equilibration
  solvers/pdhg.py            PDHG solver
  visualization/plots.py     convergence and allocation plots
tests/
  test_problem.py            matrix shape, RHS, sparsity, large-scale build
  test_solver.py             KKT conditions, linprog comparison, certificates
examples/
  basic_usage.py             end-to-end generate -> solve -> visualise

Quick start

uv sync              # create .venv and install dependencies from pyproject.toml
uv run python examples/basic_usage.py

For GPU: replace jax[cpu] with jax[cuda12] in pyproject.toml and set use_jax=True in the solver call. No other code changes are needed.

Install dev extras and run tests:

uv sync --extra dev
uv run pytest tests/ -v

Configuration

from src.solvers.pdhg import PDHGParams, PDHGSolver

params = PDHGParams(
    max_iter=10_000,
    tol=1e-6,           # relative tolerance
    check_interval=50,  # how often to evaluate residuals
    restart_interval=300,
    equilibrate=True,   # apply Ruiz before solving
    verbose=True,
)

solver = PDHGSolver(problem.A, problem.b, problem.c, params=params, use_jax=True)
result = solver.solve()

print(result.status)     # 'optimal' | 'max_iter' | 'primal_infeasible'
print(result.objective)
print(result.certificate)  # dual ray if primal_infeasible, else None

Algorithm

PDHG (Chambolle & Pock 2011) solves the saddle-point problem:

min_{x >= 0}  max_{y >= 0}  c'x + y'(b - Ax)

Updates per iteration:

y^{k+1} = max(0,  y^k + sigma * (b - A * x_bar^k))
x^{k+1} = max(0,  x^k - tau   * (c - A' * y^{k+1}))
x_bar^{k+1} = x^{k+1} + theta * (x^{k+1} - x^k)
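The three updates above transcribe directly to JAX. A dense-matrix sketch for readability; the solver itself uses sparse matrix-vector products:

```python
import jax.numpy as jnp
from jax import jit

@jit
def pdhg_step(x, x_bar, y, A, b, c, tau, sigma, theta=1.0):
    """One PDHG iteration for min_{x>=0} max_{y>=0} c'x + y'(b - Ax)."""
    y_next = jnp.maximum(0.0, y + sigma * (b - A @ x_bar))   # dual ascent + projection
    x_next = jnp.maximum(0.0, x - tau * (c - A.T @ y_next))  # primal descent + projection
    x_bar_next = x_next + theta * (x_next - x)               # extrapolation (momentum)
    return x_next, x_bar_next, y_next
```

Each step costs two matvecs plus elementwise operations, which is exactly the shape of work GPUs execute well.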

Convergence is O(1/k) for LP. The gap-based restart accelerates practical convergence on ill-conditioned financial problems.


References

  • Chambolle, A. & Pock, T. (2011). A first-order primal-dual algorithm for convex problems.
  • Applegate, D. et al. (2021). Practical large-scale linear programming using PDLP.
  • Ruiz, D. (2001). A scaling algorithm to equilibrate both rows and columns norms in matrices.
