Skip to content

Commit b00ed19

Browse files
Fix OOM in review_exf_conditions by computing stats and histograms in time chunks
Previously _var_stats and _make_histograms materialised the full time series as float64/float32 arrays (~30-60 GB per variable), and the non-negative and dewpoint checks created full boolean arrays of similar size. All four operations now iterate over TIME_CHUNK (744) time steps at a time, keeping peak memory to a few hundred MB per variable.
1 parent 0346f7f commit b00ed19

4 files changed

Lines changed: 641 additions & 17 deletions

File tree

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Animate ocean boundary condition (OBC) binary files.
2+
3+
For each (variable, boundary) pair produces one MP4:
4+
- 3D variables (U, V, T, S): pcolormesh of depth × boundary-position over time
5+
- Eta: line plot of boundary-position over time
6+
7+
Output goes to <simulation_directory>/animations/ocean_bcs/.
8+
"""
9+
10+
import os
11+
import yaml
12+
import numpy as np
13+
import xarray as xr
14+
import matplotlib
15+
matplotlib.use("Agg")
16+
import matplotlib.pyplot as plt
17+
from matplotlib.animation import FFMpegWriter
18+
19+
from spectre_utils import common
20+
21+
# Variables to animate and the four open boundaries of the domain.
VARS = ["U", "V", "T", "S", "Eta"]
BOUNDARIES = ["south", "north", "west", "east"]

# Fixed colour-scale / axis limits per variable (physically plausible ranges).
PHYSICAL_BOUNDS = {
    "U": (-3.0, 3.0),
    "V": (-3.0, 3.0),
    "T": (-2.0, 35.0),
    "S": (0.0, 42.0),
    "Eta": (-3.0, 3.0),
}

# Display units per variable, used in axis and colorbar labels.
UNITS = {
    "U": "m/s",
    "V": "m/s",
    "T": "°C",
    "S": "PSU",
    "Eta": "m",
}

# Colormap per variable; signed fields get a diverging map.
CMAPS = {
    "U": "RdBu_r",
    "V": "RdBu_r",
    "T": "plasma",
    "S": "viridis",
    "Eta": "RdBu_r",
}
39+
def animate_boundary(da, var, boundary, out_path, fps=4, dpi=100):
    """Render one (variable, boundary) DataArray as an MP4.

    3-D fields (time × depth × position) become a pcolormesh animation;
    2-D fields (time × position, i.e. Eta) become an animated line plot.
    """
    # The boundary-parallel axis is whichever dim is neither time nor depth.
    axis_dim = [d for d in da.dims if d not in ("time", "depth")][0]
    axis_vals = da.coords[axis_dim].values
    time_vals = da.coords["time"].values
    n_frames = da.sizes["time"]
    axis_label = "Longitude" if axis_dim == "lon" else "Latitude"

    vmin, vmax = PHYSICAL_BOUNDS.get(var, (None, None))
    colormap = CMAPS.get(var, "viridis")
    unit_label = UNITS.get(var, "")

    def frame(idx):
        # One time slice as float32 (memmap → small in-RAM frame).
        return da.isel(time=idx).values.astype(np.float32)

    if da.ndim == 3:
        # Depth-resolved field: depth × position heatmap per time step.
        fig, ax = plt.subplots(figsize=(10, 5), dpi=dpi)
        depth_vals = da.coords["depth"].values
        mesh = ax.pcolormesh(axis_vals, depth_vals, frame(0), shading="auto",
                             vmin=vmin, vmax=vmax, cmap=colormap)
        plt.colorbar(mesh, ax=ax, label=f"{var} [{unit_label}]")
        ax.invert_yaxis()  # depth increases downward
        ax.set_xlabel(axis_label)
        ax.set_ylabel("Depth (m)")

        def update(idx):
            mesh.set_array(frame(idx))

    else:
        # Surface field (Eta): single line along the boundary position.
        fig, ax = plt.subplots(figsize=(10, 3), dpi=dpi)
        (line,) = ax.plot(axis_vals, frame(0), color="steelblue", lw=0.8)
        if vmin is not None:
            # Pad the fixed physical bounds by 10% so the line never clips.
            ax.set_ylim(vmin - abs(vmin) * 0.1, vmax + abs(vmax) * 0.1)
        ax.set_xlabel(axis_label)
        ax.set_ylabel(f"{var} [{unit_label}]")
        ax.grid(True, alpha=0.3)

        def update(idx):
            line.set_ydata(frame(idx))

    title = ax.set_title(
        f"{var}{boundary} | {np.datetime_as_string(time_vals[0], unit='D')}"
    )

    writer = FFMpegWriter(fps=fps, bitrate=-1)
    with writer.saving(fig, out_path, dpi):
        writer.grab_frame()  # frame 0 was drawn above
        for idx in range(1, n_frames):
            update(idx)
            title.set_text(
                f"{var}{boundary} | {np.datetime_as_string(time_vals[idx], unit='D')}"
            )
            writer.grab_frame()

    plt.close(fig)
    print(f" Saved {out_path}")
98+
def main():
    """CLI entry point: load OBC binaries and write one MP4 per (var, boundary)."""
    args = common.cli()

    with open(args.config_file, "r") as f:
        config = yaml.safe_load(f)

    simulation_directory = config["simulation_directory"]
    working_directory = config["working_directory"]
    simulation_input_dir = os.path.join(simulation_directory, "input")

    # All animations go under <simulation_directory>/animations/ocean_bcs/.
    animations_dir = os.path.join(simulation_directory, "animations", "ocean_bcs")
    os.makedirs(animations_dir, exist_ok=True)

    print("Loading OBC binary files...")
    data = common.load_obc_binaries(simulation_input_dir, working_directory, config)

    for var in VARS:
        for bnd in BOUNDARIES:
            da = data.get((var, bnd))
            if da is None:
                # Boundary binary was absent on disk; loader already warned.
                print(f" Skipping {var}.{bnd} (not found)")
                continue
            out_path = os.path.join(animations_dir, f"{var}_{bnd}.mp4")
            print(f"Animating {var}.{bnd} -> {out_path}")
            animate_boundary(da, var, bnd, out_path, fps=4)

    print("Done.")


if __name__ == "__main__":
    main()

spectre_utils/common.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,103 @@ def load_exf_binaries(simulation_input_dir, var_names, working_directory, prefix
116116
return xr.Dataset(data_vars)
117117

118118

119+
def load_obc_binaries(simulation_input_dir, working_directory, config):
    """Load ocean boundary condition binary files into a dict of labelled DataArrays.

    Returns a mapping ``{(var, boundary): xr.DataArray}`` where:
      - ``var`` ∈ {"U", "V", "T", "S", "Eta"}
      - ``boundary`` ∈ {"south", "north", "west", "east"}

    Coordinates (time, depth, and boundary-parallel position as longitude or
    latitude) are extracted from the raw glorys12 NetCDF files. The binary
    data are memory-mapped so only accessed portions are read into RAM.
    """
    import xarray as xr
    import numpy as np
    import os
    import sys  # BUGFIX: sys was referenced below (stderr warning) but never imported
    import glob as _glob

    # Domain trim indices; defaults match the standard glorys12 crop.
    i0 = config.get("domain", {}).get("longitude", {}).get("start", 2)
    i1 = config.get("domain", {}).get("longitude", {}).get("end", -2)
    j0 = config.get("domain", {}).get("latitude", {}).get("start", 1)
    j1 = config.get("domain", {}).get("latitude", {}).get("end", -1)
    prefix = config.get("ocean", {}).get("prefix", "glorysv12")

    t_files = sorted(_glob.glob(f"{working_directory}/{prefix}_T_glorys12_raw.*.nc"))
    if not t_files:
        raise FileNotFoundError(f"No glorys12 T files found in {working_directory}")

    # One file for static coords (depth, nav_lon, nav_lat, grid size).
    one_ds = xr.open_dataset(t_files[0])
    nz = one_ds.sizes["deptht"]
    ny = one_ds.sizes["y"]
    nx = one_ds.sizes["x"]
    depth = one_ds["deptht"].values
    nav_lon = one_ds["nav_lon"].values  # (ny, nx)
    nav_lat = one_ds["nav_lat"].values  # (ny, nx)
    one_ds.close()

    # Full time axis from all T files (lazy — no data read).
    all_ds = xr.open_mfdataset(t_files, combine="by_coords")
    times = all_ds["time_counter"].values
    all_ds.close()
    nt = len(times)

    # Resolve negative indices to absolute positions.
    i1_abs = nx + i1 if i1 < 0 else i1  # = nx - 2
    j1_abs = ny + j1 if j1 < 0 else j1  # = ny - 1

    # 1-D geographic coordinates along each boundary.
    # South/north: position varies in x (longitude-like).
    lon_sn = nav_lon[j0, i0:i1_abs]          # V, T, S, Eta
    lon_sn_u = nav_lon[j0, i0:i1_abs - 1]    # U (narrower by 1; C-grid stagger)
    lat_sn = nav_lat[j0, i0:i1_abs]
    lat_sn_u = nav_lat[j0, i0:i1_abs - 1]
    # West/east: position varies in y (latitude-like).
    lon_we = nav_lon[j0:j1_abs, i0]
    lat_we = nav_lat[j0:j1_abs, i0]

    # (shape, pos_dim_name, pos_coord) keyed by (var, boundary).
    _specs = {}
    for bnd in ("south", "north"):
        _specs[("U", bnd)] = ((nt, nz, len(lon_sn_u)), "lon", lon_sn_u)
        _specs[("V", bnd)] = ((nt, nz, len(lon_sn)), "lon", lon_sn)
        _specs[("T", bnd)] = ((nt, nz, len(lon_sn)), "lon", lon_sn)
        _specs[("S", bnd)] = ((nt, nz, len(lon_sn)), "lon", lon_sn)
        _specs[("Eta", bnd)] = ((nt, len(lon_sn)), "lon", lon_sn)
    for bnd in ("west", "east"):
        _specs[("U", bnd)] = ((nt, nz, len(lat_we)), "lat", lat_we)
        _specs[("V", bnd)] = ((nt, nz, len(lat_we)), "lat", lat_we)
        _specs[("T", bnd)] = ((nt, nz, len(lat_we)), "lat", lat_we)
        _specs[("S", bnd)] = ((nt, nz, len(lat_we)), "lat", lat_we)
        _specs[("Eta", bnd)] = ((nt, len(lat_we)), "lat", lat_we)

    result = {}
    for (var, bnd), (shape, pos_name, pos_coord) in _specs.items():
        bin_path = os.path.join(simulation_input_dir, f"{var}.{bnd}.bin")
        if not os.path.exists(bin_path):
            # Best-effort: warn and skip; callers check membership in the dict.
            print(f"Warning: binary file not found, skipping: {bin_path}", file=sys.stderr)
            continue
        # Big-endian float32 matches MITgcm's binary convention; mode="r"
        # memory-maps the file read-only.
        arr = np.memmap(bin_path, dtype=">f4", mode="r", shape=shape)
        if len(shape) == 3:  # time × depth × position
            da = xr.DataArray(
                arr,
                dims=["time", "depth", pos_name],
                coords={"time": times, "depth": depth, pos_name: pos_coord},
                name=f"{var}_{bnd}",
            )
        else:  # time × position (Eta)
            da = xr.DataArray(
                arr,
                dims=["time", pos_name],
                coords={"time": times, pos_name: pos_coord},
                name=f"{var}_{bnd}",
            )
        result[(var, bnd)] = da

    return result
214+
215+
119216
def load_atm_dataset(working_directory, prefix, years, atm_vars, t1, t2):
120217
"""Load ERA5 atmospheric variables per MITgcm name, rename, and apply optional scale factors.
121218

spectre_utils/review_exf_conditions.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,44 @@
4646
# Variables that should be non-negative after any scale factor
4747
NON_NEGATIVE = {"swdown", "lwdown", "precip", "runoff", "aqh"}
4848

49+
# Chunk length (in time steps) for the memory-bounded passes below.
# ERA5 is hourly, so 744 steps ≈ one 31-day month.
TIME_CHUNK = 744
52+
4953

5054
# ---------------------------------------------------------------------------
5155
# Helpers
5256
# ---------------------------------------------------------------------------
5357

5458
def _var_stats(da: xr.DataArray) -> dict:
    """Compute global min/mean/max and count of non-finite values in time chunks."""
    # Running accumulators over all chunks; min/max start at their identities.
    total_count = 0
    bad_count = 0
    running_min = np.inf
    running_max = -np.inf
    running_sum = 0.0

    n_steps = da.sizes["valid_time"]
    for start in range(0, n_steps, TIME_CHUNK):
        block = da.isel(valid_time=slice(start, start + TIME_CHUNK)).values.astype(np.float64)
        total_count += block.size
        mask = np.isfinite(block)
        bad_count += int((~mask).sum())
        finite_vals = block[mask]
        if finite_vals.size:
            running_min = min(running_min, float(finite_vals.min()))
            running_max = max(running_max, float(finite_vals.max()))
            running_sum += float(finite_vals.sum())

    valid_count = total_count - bad_count
    if valid_count == 0:
        return dict(min=np.nan, mean=np.nan, max=np.nan, n_bad=bad_count, n=total_count)
    return dict(
        min=running_min if np.isfinite(running_min) else np.nan,
        mean=running_sum / valid_count,
        max=running_max if np.isfinite(running_max) else np.nan,
        n_bad=bad_count,
        n=total_count,
    )
6888

6989

@@ -166,9 +186,32 @@ def _make_histograms(ds: xr.Dataset, var_names: list[str], out_path: str) -> Non
166186

167187
for idx, name in enumerate(present):
168188
ax = axes[idx // ncols][idx % ncols]
169-
vals = ds[name].values.ravel()
170-
vals = vals[np.isfinite(vals)]
171-
ax.hist(vals, bins=120, color="steelblue", edgecolor="none", density=True)
189+
190+
# Determine bin range from known physical bounds or the first time chunk.
191+
if name in PHYSICAL_BOUNDS:
192+
lo, hi = PHYSICAL_BOUNDS[name]
193+
else:
194+
first = ds[name].isel(valid_time=slice(0, TIME_CHUNK)).values.ravel()
195+
first = first[np.isfinite(first)]
196+
lo = float(first.min()) if first.size > 0 else 0.0
197+
hi = float(first.max()) if first.size > 0 else 1.0
198+
199+
bins = np.linspace(lo, hi, 121)
200+
counts = np.zeros(120, dtype=np.float64)
201+
202+
nt = ds[name].sizes["valid_time"]
203+
for i in range(0, nt, TIME_CHUNK):
204+
chunk = ds[name].isel(valid_time=slice(i, i + TIME_CHUNK)).values.ravel()
205+
chunk = chunk[np.isfinite(chunk)]
206+
c, _ = np.histogram(chunk, bins=bins)
207+
counts += c
208+
209+
bin_centers = (bins[:-1] + bins[1:]) / 2
210+
bin_width = bins[1] - bins[0]
211+
total = counts.sum()
212+
density = counts / (total * bin_width) if total > 0 else counts
213+
ax.bar(bin_centers, density, width=bin_width, color="steelblue", edgecolor="none")
214+
172215
var_units = ds[name].attrs.get("units", "")
173216
ax.set_xlabel(f"{name} [{var_units}]" if var_units else name, fontsize=8)
174217
ax.set_ylabel("Density", fontsize=8)
@@ -264,7 +307,8 @@ def main():
264307
stat_rows.append((name, "—", "—", "—", "—", "—"))
265308
continue
266309

267-
st = _var_stats(ds[name])
310+
da = ds[name]
311+
st = _var_stats(da)
268312
checks: list[tuple[bool, str]] = []
269313

270314
# 1. NaN / Inf
@@ -290,7 +334,11 @@ def main():
290334

291335
# 3. Non-negative where required
292336
if name in NON_NEGATIVE:
293-
n_neg = int((ds[name] < 0).sum().compute().item())
337+
n_neg = 0
338+
_nt = da.sizes["valid_time"]
339+
for _i in range(0, _nt, TIME_CHUNK):
340+
_chunk = da.isel(valid_time=slice(_i, _i + TIME_CHUNK)).values
341+
n_neg += int((_chunk < 0).sum())
294342
if n_neg == 0:
295343
checks.append((True, "All values ≥ 0"))
296344
else:
@@ -331,8 +379,14 @@ def main():
331379

332380
# Dewpoint ≤ air temperature
333381
if "d2m" in ds and "atemp" in ds:
334-
n_viol = int((ds["d2m"] > ds["atemp"]).sum().compute().item())
335-
n_total = int(ds["d2m"].size)
382+
n_viol = 0
383+
n_total = 0
384+
nt = ds["d2m"].sizes["valid_time"]
385+
for _i in range(0, nt, TIME_CHUNK):
386+
_d2m = ds["d2m"].isel(valid_time=slice(_i, _i + TIME_CHUNK)).values
387+
_atemp = ds["atemp"].isel(valid_time=slice(_i, _i + TIME_CHUNK)).values
388+
n_viol += int((_d2m > _atemp).sum())
389+
n_total += _d2m.size
336390
ok = n_viol == 0
337391
msg = (
338392
"Dewpoint ≤ air temperature everywhere"

0 commit comments

Comments
 (0)