Skip to content

Commit 8511733

Browse files
Reduce memory usage in mk_exf_conditions by processing variables one at a time in chunks
Previously all variables were merged into one dataset and `.values` was called on the full time series, loading ~29 GB into RAM for each variable at once. Now each variable is opened, written in ~1-month time chunks, and closed before the next variable is processed, keeping peak memory to a few hundred MB per chunk.
1 parent 11344c5 commit 8511733

1 file changed

Lines changed: 80 additions & 24 deletions

File tree

spectre_utils/mk_exf_conditions.py

Lines changed: 80 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,107 @@
11
import os
2+
import sys
23
from spectre_utils import common
34
import yaml
45
from metpy.calc import specific_humidity_from_dewpoint
56
from metpy.units import units
67
from datetime import datetime
8+
import numpy as np
9+
import xarray as xr
10+
11+
# Number of time steps to load and write at once.
12+
# ERA5 is hourly, so 744 ≈ one month. Tune down if memory is still tight.
13+
TIME_CHUNK = 744
14+
15+
16+
def _open_var(working_directory, prefix, mitgcm_name, years, t1, t2):
    """Lazily open one ERA5 variable spanning all requested years.

    Builds the per-year NetCDF paths, aborts with exit code 1 on the first
    missing file, then opens everything as a single dask-chunked dataset
    restricted to the [t1, t2] time window. When the files contain exactly
    one data variable under a different name, it is renamed to
    ``mitgcm_name`` so downstream code can index by the MITgcm name.
    """
    paths = [f"{working_directory}/{prefix}_{mitgcm_name}_{year}.nc" for year in years]
    for path in paths:
        if os.path.exists(path):
            continue
        # Fail fast before any expensive open, rather than mid-run.
        print(f"Missing file: {path}", file=sys.stderr)
        sys.exit(1)

    dataset = xr.open_mfdataset(
        paths, combine="by_coords", chunks={"valid_time": TIME_CHUNK}
    ).sel(valid_time=slice(t1, t2))

    # ERA5 short names (e.g. "u10") may differ from the MITgcm name used by
    # the rest of the pipeline; normalise only when the mapping is unambiguous.
    names = list(dataset.data_vars)
    if len(names) == 1 and names[0] != mitgcm_name:
        dataset = dataset.rename({names[0]: mitgcm_name})
    return dataset
30+
31+
32+
def _write_chunked(da, output_path):
    """Stream *da* to *output_path* as raw big-endian float32 records.

    Data is materialised TIME_CHUNK time steps at a time, so peak memory is
    bounded by one chunk regardless of the full series length.
    """
    total = da.sizes["valid_time"]
    with open(output_path, "wb") as sink:
        start = 0
        while start < total:
            window = slice(start, start + TIME_CHUNK)
            block = da.isel(valid_time=window).values
            # MITgcm EXF input expects big-endian float32 (">f4").
            block.astype(">f4").tofile(sink)
            start += TIME_CHUNK
39+
740

841
def main():
    """Convert ERA5 forcing variables into MITgcm EXF binary input files.

    Each variable is opened lazily, streamed to disk in TIME_CHUNK-sized
    time slices, and closed before the next one is touched, keeping peak
    memory bounded by a single chunk instead of the full time series.
    """
    args = common.cli()

    # Read the YAML run configuration.
    with open(args.config_file, "r") as cfg:
        config = yaml.safe_load(cfg)

    working_directory = config["working_directory"]
    simulation_directory = config["simulation_directory"]
    atmosphere = config["atmosphere"]
    years = atmosphere["years"]
    atm_vars = atmosphere["variables"]
    computed_vars = atmosphere.get("computed_variables", [])
    prefix = atmosphere["prefix"]
    simulation_input_dir = os.path.join(simulation_directory, "input")

    t1 = datetime.strptime(config["domain"]["time"]["start"], "%Y-%m-%d")
    t2 = datetime.strptime(config["domain"]["time"]["end"], "%Y-%m-%d")

    # Direct variables: open -> stream -> close, one at a time, skipping
    # duplicates that map onto the same MITgcm name.
    seen = set()
    for var in atm_vars:
        mitgcm_name = var["mitgcm_name"]
        if mitgcm_name in seen:
            continue
        seen.add(mitgcm_name)

        print(f"Processing {mitgcm_name}...")
        scale_factor = var.get("scale_factor")

        dataset = _open_var(working_directory, prefix, mitgcm_name, years, t1, t2)
        field = dataset[mitgcm_name]
        if scale_factor is not None:
            field = field * scale_factor

        _write_chunked(field, os.path.join(simulation_input_dir, f"{mitgcm_name}.bin"))
        dataset.close()

    # Derived variables: only the two inputs (d2m, sp) are held open at a
    # time, and output is written in TIME_CHUNK slices as above.
    # NOTE(review): this loop always computes specific humidity from
    # d2m + sp regardless of the configured name — confirm that aqh is the
    # only supported computed variable.
    for cv in computed_vars:
        mitgcm_name = cv["mitgcm_name"]
        print(f"Computing {mitgcm_name}...")

        ds_d2m = _open_var(working_directory, prefix, "d2m", years, t1, t2)
        ds_sp = _open_var(working_directory, prefix, "sp", years, t1, t2)
        n_times = ds_d2m.sizes["valid_time"]

        output_path = os.path.join(simulation_input_dir, f"{mitgcm_name}.bin")
        with open(output_path, "wb") as out:
            start = 0
            while start < n_times:
                window = slice(start, start + TIME_CHUNK)
                dewpoint_k = ds_d2m["d2m"].isel(valid_time=window).values
                pressure_pa = ds_sp["sp"].isel(valid_time=window).values
                # MetPy takes pressure first, then dewpoint in Celsius.
                aqh = specific_humidity_from_dewpoint(
                    pressure_pa * units.Pa, (dewpoint_k - 273.15) * units.degC
                )
                np.array(aqh).astype(">f4").tofile(out)
                start += TIME_CHUNK

        ds_d2m.close()
        ds_sp.close()
104+
105+
50106
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)