Skip to content

Commit 3a37e4a

Browse files
authored
Fix erros in cupy Q_bar and Q_bar_N construction
1 parent 6a7c61d commit 3a37e4a

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

pimpc/solver.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,19 @@ def _solve_gpu(model, x0, u0, yref, uref, w, *, warm_vars=None, verbose=False):
292292
E_inv = cp.ones(nx_bar, dtype=dtype)
293293

294294
C_part = C_bar[:, :nx]
295-
Q_bar = cp.block(
295+
Q11 = C_part.T @ cp.asarray(model.Wy, dtype=dtype) @ C_part
296+
Q22 = cp.asarray(model.Wu, dtype=dtype)
297+
Q_bar = cp.vstack(
296298
[
297-
[C_part.T @ cp.asarray(model.Wy, dtype=dtype) @ C_part, cp.zeros((nx, nu), dtype=dtype)],
298-
[cp.zeros((nu, nx), dtype=dtype), cp.asarray(model.Wu, dtype=dtype)],
299+
cp.hstack([Q11, cp.zeros((nx, nu), dtype=dtype)]),
300+
cp.hstack([cp.zeros((nu, nx), dtype=dtype), Q22]),
299301
]
300302
)
301-
Q_bar_N = cp.block(
303+
Q11_N = C_part.T @ cp.asarray(model.Wf, dtype=dtype) @ C_part
304+
Q_bar_N = cp.vstack(
302305
[
303-
[C_part.T @ cp.asarray(model.Wf, dtype=dtype) @ C_part, cp.zeros((nx, nu), dtype=dtype)],
304-
[cp.zeros((nu, nx), dtype=dtype), cp.asarray(model.Wu, dtype=dtype)],
306+
cp.hstack([Q11_N, cp.zeros((nx, nu), dtype=dtype)]),
307+
cp.hstack([cp.zeros((nu, nx), dtype=dtype), Q22]),
305308
]
306309
)
307310
q_bar = cp.concatenate(

0 commit comments

Comments
 (0)