Commit d7033ab

Add scale-down debugging protocol and update compile overhead gotcha
- Add mandatory scale-down debugging protocol: always validate 1 GPU first, scale up incrementally, never debug at large scale
- Update BL compile overhead gotcha to mark as under investigation (likely a bug in our branch, not inherent to all-gather GP)

1 parent 69a64c4 commit d7033ab

1 file changed: CLAUDE.md (37 additions & 0 deletions)
@@ -1,6 +1,17 @@
# CLAUDE.md
FAIRChem is Meta FAIR Chemistry's ML framework for atomistic simulations. Core abstractions: foundation models (UMA) with backbone+heads architecture, ASE calculator integration, Hydra-based config, and multi-task training via TorchTNT.

## Scale-Down Debugging Protocol (MANDATORY)
**ALWAYS validate changes bottom-up: 1 GPU → 2 GPUs → 8 GPUs → multi-node.**
1. **Start at 1 GPU**: Run all tests and sanity checks against the baseline (main branch). Compare compile time, inference speed, and numerical outputs. If anything differs from main at 1 GPU, fix it before scaling up.
2. **Scale incrementally**: Only after 1 GPU passes, move to 2 GPUs, then 8 GPUs (1 node), then multi-node.
3. **If something breaks or doesn't make sense at large scale**: ALWAYS scale DOWN until you find the bug. NEVER debug at large scale — it wastes GPU-hours, has long queue times, and makes root-cause analysis impossible.
4. **Baseline comparison at every scale**: At each GPU count, compare your branch against main with identical settings. Any regression in compile time, throughput, or correctness must be investigated at the smallest scale where it reproduces.
This applies to: compile time, inference throughput, numerical accuracy, memory usage, and any other observable behavior. No exceptions.
## Development Commands

```bash
@@ -292,3 +303,29 @@ The `AllToAllCollect.forward()` uses `torch.autograd.Function.apply()`. Every ar
### Prefer all_to_all_single over all_to_all for packed tensors
For NCCL, `dist.all_to_all_single(output, input, output_split_sizes, input_split_sizes, group)` operates on packed contiguous tensors directly, avoiding the Python list creation overhead from `tensor.split()` + `list()` that `dist.all_to_all(output_list, input_list, group)` requires. When send/recv data is already contiguous (as in AllToAllCollect), `all_to_all_single` is more efficient. However, it's not supported on Gloo — always provide a fallback. Also note that `split()` creates views into the original tensor, so after `all_to_all` fills the views, the original buffer already contains the data — a subsequent `torch.cat()` on the views is a redundant copy.
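
A minimal sketch of the packed pattern (illustrative only: the function name, the `use_single` switch, and the buffer layout are assumptions for this example, not the actual `AllToAllCollect` code):

```python
import torch
import torch.distributed as dist

def exchange_packed(send_buf, in_splits, out_splits, group=None, use_single=True):
    """Exchange a packed contiguous buffer between ranks.

    send_buf is laid out rank-by-rank along dim 0; in_splits[i] rows go to
    rank i, out_splits[i] rows arrive from rank i (plain Python int lists).
    """
    recv_buf = send_buf.new_empty((sum(out_splits),) + send_buf.shape[1:])
    if use_single:
        # Packed path: operates on the contiguous buffers directly,
        # no Python list creation.
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=out_splits,
            input_split_sizes=in_splits,
            group=group,
        )
    else:
        # List-based fallback. split() returns views into recv_buf, so the
        # collective writes land in recv_buf directly; a trailing torch.cat()
        # on the views would be a redundant copy.
        recv_views = list(recv_buf.split(out_splits))
        send_views = list(send_buf.split(in_splits))
        dist.all_to_all(recv_views, send_views, group=group)
    return recv_buf
```
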
### Morton Z-order normalization must use global scale
When computing Morton Z-order codes for spatial partitioning, normalize positions using a SINGLE global scale factor (the largest bounding-box extent), NOT per-dimension normalization. Per-dimension normalization amplifies noise in short dimensions — e.g., a 100-unit x-gap becomes indistinguishable from a 2-unit y-gap after independent rescaling, destroying spatial locality in the Morton curve. Use `extent = (pos.max(0)[0] - min_pos).max()` instead of per-dimension `extent = pos.max(0)[0] - min_pos`.
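
A short sketch of the intended normalization (the function name and bit width are illustrative assumptions):

```python
import torch

def morton_grid_coords(pos: torch.Tensor, bits: int = 10) -> torch.Tensor:
    """Map positions of shape (N, 3) to integer grid coordinates for Morton encoding."""
    min_pos = pos.min(dim=0).values
    # Single global scale: the largest bounding-box extent over all dimensions.
    # This keeps relative distances between axes intact.
    extent = (pos.max(dim=0).values - min_pos).max().clamp_min(1e-12)
    # NOT per-dimension: extent = pos.max(dim=0).values - min_pos
    # (rescales each axis independently and destroys locality in short dimensions)
    grid = ((pos - min_pos) / extent * (2**bits - 1)).long()
    return grid.clamp_(0, 2**bits - 1)
```
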
### All-gather baseline also has per-layer collectives
The all-gather GP baseline has one all-gather per `Edgewise.forward()` layer (not just one upfront). UMA-S has 4 layers, so the baseline does 4 all-gathers per forward pass. The A2A path does 4 `AllToAllCollect` calls per forward pass (each sending less data) plus a setup `all_to_all` for index exchange. The performance comparison therefore hinges on setup overhead, not per-layer communication volume.
### Morton partition must use balanced splitting, not ceil-based chunking
The `partition_atoms_spatial()` function must use `arange(N) * P // N` (balanced splitting) instead of `arange(N) // ceil(N/P)` (ceil-based chunking). The ceil-based formula leaves trailing ranks empty when N is not a multiple of P — e.g., 1000 atoms / 64 ranks gives rank 63 zero atoms because `ceil(1000/64) = 16` and `63 * 16 = 1008 > 1000`. This causes NCCL hangs: the empty rank crashes at the empty-partition assertion while the other ranks block on the collective. The balanced formula `i * P // N` distributes atoms evenly: every rank gets either `floor(N/P)` or `ceil(N/P)` atoms, so no rank is empty whenever N ≥ P.
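
A standalone sketch contrasting the two formulas (not the actual `partition_atoms_spatial()` body):

```python
import math
import torch

N, P = 1000, 64  # atoms, ranks

# Ceil-based chunking: trailing ranks can end up empty.
chunk = math.ceil(N / P)                       # 16
bad_rank = torch.arange(N) // chunk            # values 0..62 only; rank 63 gets nothing
print(torch.bincount(bad_rank, minlength=P)[-1].item())   # 0

# Balanced splitting: every rank gets floor(N/P) or ceil(N/P) atoms.
rank = torch.arange(N) * P // N
counts = torch.bincount(rank, minlength=P)
print(counts.min().item(), counts.max().item())            # 15 16
```
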
### Communication is <1% of UMA-S forward pass time
At 8 GPUs (intra-node NVLink), the all-gather per-layer communication takes ~0.015ms vs ~65-98ms total forward pass (0.02%). At 64 GPUs across 8 nodes, it's still only ~1ms vs ~65-91ms total (1.1%). The performance numbers in benchmark results are in **ns/day** (not QPS) — convert with `QPS = ns_per_day × 1e6 / 86400`. This means reducing communication volume alone cannot produce meaningful speedups; A2A must also reduce overhead in other areas (memory, synchronization, etc.) to be competitive.
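
Applying the stated conversion (the 500 ns/day input below is just an illustrative number):

```python
def ns_per_day_to_qps(ns_per_day: float) -> float:
    # QPS = ns_per_day * 1e6 / 86400, per the convention above
    return ns_per_day * 1e6 / 86400

print(ns_per_day_to_qps(500.0))  # ~5787 QPS
```
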
### UMASFastPytorchBackend requires activation_checkpointing=False
When using `execution_mode: "umas_fast_gpu"` (the default speed benchmark mode), `activation_checkpointing` must be `False`. Setting it to `True` raises `ValueError: UMASFastPytorchBackend requires activation_checkpointing=False`. For benchmarks with `compile=False`, just omit the activation_checkpointing override entirely (defaults to False).
### Always use inference presets (turbo/default), not manual settings
Use the standard presets from `inference.py` — don't manually mix and match settings like `compile=True` with `tf32=False`. The presets are:
- **turbo**: tf32=True, activation_checkpointing=False, merge_mole=True, compile=True, execution_mode=None (auto → umas_fast_gpu)
- **default**: tf32=False, activation_checkpointing=True, merge_mole=False, compile=False, execution_mode=None (auto → general, since act_ckpt=True fails umas_fast_gpu validation)
- **traineval**: tf32=False, activation_checkpointing=False, merge_mole=False, compile=False, internal_graph_gen_version=1
Via Hydra CLI, override individual fields (not by preset name): `runner.inference_settings.tf32=false runner.inference_settings.compile=false` etc. To set execution_mode to auto-detect: `runner.inference_settings.execution_mode=null`.
### BL torch.compile overhead — UNDER INVESTIGATION
At 64 GPUs with all-gather GP on our branch, torch.compile took ~23 min per atom count (92 min total for 4 sizes). This is likely a BUG introduced by our code changes (e.g., `@torch.compiler.disable` on `_generate_graph`, additional branches in `Edgewise.forward`), NOT an inherent property of all-gather GP. Normal GP=64 compile on main takes <10 minutes. The 1-GPU comparison test will confirm whether our branch regressed compile time. Do NOT claim this as an A2A advantage until the root cause is confirmed.
