Commit d61a76a

Implement training and evaluation orchestration for AlphaFold2
- Added test for training and evaluation metrics preference when both loaders are present.
- Introduced training utilities for optimization, metrics, checkpointing, and runtime in `training/__init__.py`.
- Developed checkpoint save and restore utilities in `training/checkpoints.py`, including model serialization and state tracking.
- Created efficient metrics computation for RMSD, TM-score, and GDT-TS in `training/efficient_metrics.py`.
- Implemented evaluation utilities for AlphaFold2-like runs in `training/eval_one_epoch.py`.
- Added parallel training helpers for data, model, and hybrid execution modes in `training/train_parallel/__init__.py`.
- Developed distributed and data-parallel helpers for multi-GPU training in `training/train_parallel/data_parallel.py`.
- Created two-stage model-parallel wrappers for the AlphaFold2 model in `training/train_parallel/model_parallel.py`.
1 parent 9c8c9c3 commit d61a76a
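The commit message above mentions efficient RMSD computation in `training/efficient_metrics.py`. Structure-comparison RMSD is conventionally computed after optimal superposition via the Kabsch algorithm; the following is a minimal NumPy sketch of that idea, not the repository's actual implementation (the function name and API here are illustrative):

```python
import numpy as np

def kabsch_rmsd(pred: np.ndarray, target: np.ndarray) -> float:
    """Superimpose pred onto target with the Kabsch algorithm, then return RMSD.

    Both inputs are (N, 3) coordinate arrays (e.g. CA atoms). Translation and
    rotation are removed before measuring deviation, as is standard for
    structure metrics. Illustrative sketch, not the repo's API.
    """
    # Center both coordinate sets at the origin (removes translation).
    p = pred - pred.mean(axis=0)
    q = target - target.mean(axis=0)
    # Optimal rotation from the SVD of the 3x3 covariance matrix.
    u, _, vt = np.linalg.svd(p.T @ q)
    # Correct for a possible reflection so the result is a proper rotation.
    d = np.sign(np.linalg.det(vt.T @ u.T))
    rot = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    aligned = p @ rot.T
    return float(np.sqrt(((aligned - q) ** 2).sum(axis=1).mean()))
```

TM-score and GDT-TS follow the same pattern (superpose first, then score per-residue distances against thresholds), which is why a shared alignment step is the natural core of an "efficient metrics" module.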

57 files changed: +1723 −151 lines

.github/release.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+changelog:
+  exclude:
+    labels:
+      - skip-changelog
+  categories:
+    - title: Breaking Changes
+      labels:
+        - breaking-change
+    - title: Architecture
+      labels:
+        - architecture
+        - model
+    - title: Training and Evaluation
+      labels:
+        - training
+        - evaluation
+    - title: Data Pipeline
+      labels:
+        - data
+    - title: Tests and CI
+      labels:
+        - testing
+        - ci
+    - title: Documentation
+      labels:
+        - documentation
+    - title: Other Changes
+      labels:
+        - '*'

.github/release_template.md

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+## Summary
+-
+
+## Highlights
+-
+
+## Validation
+-
+
+## Notes
+-
+
+## Known Limitations
+-

.github/workflows/ci.yml

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+name: CI
+
+on:
+  push:
+    branches:
+      - main
+      - master
+  pull_request:
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10", "3.11"]
+
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: pip
+
+      - name: Install package and test dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e '.[dev,data]'
+
+      - name: Run CPU-safe module and data tests
+        run: |
+          python -m pytest -q \
+            tests/test_module_integrity.py \
+            tests/test_ipa.py \
+            tests/test_row_column_attention.py \
+            tests/test_triangle_attention.py \
+            tests/test_triangle_multiplication.py \
+            tests/test_opm.py \
+            tests/test_extra_msa_stack.py \
+            tests/test_template_stack.py \
+            tests/test_metrics.py \
+            tests/test_loader_wrappers.py \
+            tests/test_showcase_loader.py \
+            tests/test_train_eval_orchestration.py \
+            tests/test_eval_one_epoch.py

.gitignore

Lines changed: 9 additions & 2 deletions
@@ -3,12 +3,19 @@ __pycache__/
 .venv/
 *.pyc
 *.pyo
+*.pyd
+*.egg-info/
+build/
+dist/
+.coverage
+htmlcov/
+.mypy_cache/
+.ruff_cache/
 data/af_subset/
 checkpoints*/
+artifacts/
 .venv_tmp_data
 .venv_tmp_crop
 .agents
 .skills-lock
 
-
-

CONTRIBUTING.md

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ Create an isolated environment and install dependencies:
 python3 -m venv .venv
 source .venv/bin/activate
 pip install -r requirements.txt
+
+# Optional editable install for local CLI entry points
+pip install -e '.[dev,data]'
 ```
 
 If you work from Conda, use the equivalent environment setup and install the same requirements.

Dockerfile

Lines changed: 3 additions & 4 deletions
@@ -12,10 +12,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     git \
     && rm -rf /var/lib/apt/lists/*
 
-COPY requirements.txt /tmp/requirements.txt
-RUN pip install --no-cache-dir --upgrade pip && \
-    pip install --no-cache-dir -r /tmp/requirements.txt
-
 COPY . /app
 
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir -e ".[data]"
+
 CMD ["bash"]

README.md

Lines changed: 17 additions & 9 deletions
@@ -10,6 +10,7 @@
 
 [![Python](https://img.shields.io/badge/Python-3.10%2B-blue.svg)](#installation)
 [![PyTorch](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg)](#installation)
+[![CI](https://github.com/pablo-reyes8/alpha-fold2/actions/workflows/ci.yml/badge.svg)](https://github.com/pablo-reyes8/alpha-fold2/actions/workflows/ci.yml)
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](#license)
 [![Status](https://img.shields.io/badge/status-Research%20Prototype-orange)](#project-status)
 
@@ -70,7 +71,7 @@ To make experimentation easier to reproduce, the repository follows a **manifest
 - **Config-Driven Experiments:** Main settings such as model size, depth, learning rate, and EMA can be adjusted through YAML files.
 - **Feature-Rich Loader:** The current dataloader returns sequence/MSA tensors plus `extra_msa_feat`, `extra_msa_mask`, `template_angle_feat`, `template_pair_feat`, and `template_mask` when those artifacts are present in the Foldbench assets.
 - **Data Inspection Utilities:** Provides simple CLI tools to inspect manifests, preview A3M files, and visualize CA distance maps before training.
-- **Notebook-Friendly Workflow:** The main walkthrough notebook is [Alpha_Fold_English.ipynb](notebooks/Alpha_Fold_English.ipynb), and a local training-focused version is available in [notebooks\train_model_setup_examples.ipynb](notebooks/train_model_local.ipynb).
+- **Notebook-Friendly Workflow:** The main walkthrough notebook is [Alpha_Fold_English.ipynb](notebooks/Alpha_Fold_English.ipynb), and a local training-focused walkthrough is available in [train_model_setup_examples.ipynb](notebooks/train_model_setup_examples.ipynb).
 
 ---
 
@@ -84,7 +85,8 @@ To make experimentation easier to reproduce, the repository follows a **manifest
 ├── data/            # manifest-based data pipeline plus a tiny bundled showcase subset
 │   ├── download_data.sh
 │   ├── foldbench.py
-│   ├── preproces_data.py
+│   ├── preprocess_data.py
+│   ├── loader_wrappers.py
 │   ├── dataloaders.py
 │   ├── collate_proteins.py
 │   ├── visualize_data.py
@@ -94,7 +96,7 @@ To make experimentation easier to reproduce, the repository follows a **manifest
 │   └── losses/
 ├── training/        # single-device training loop, ablation registry, AMP, EMA, checkpoints, and metrics
 │   ├── ablations/   # predefined architecture and loss ablation presets
-│   └── train_paralel/   # DDP and model-parallel helpers
+│   └── train_parallel/  # DDP and model-parallel helpers
 ├── scripts/         # operational CLIs for data prep, validation, and training
 │   ├── prepare_data.py
 │   ├── inspect_data.py
@@ -108,6 +110,7 @@ To make experimentation easier to reproduce, the repository follows a **manifest
 ├── notebooks/       # interactive experiments for Colab or local exploration
 ├── paper/           # reference material from the AlphaFold paper and notes
 ├── assets/          # README visuals and showcase media
+├── pyproject.toml
 ├── requirements.txt
 ├── Dockerfile
 └── README.md
@@ -116,7 +119,8 @@ To make experimentation easier to reproduce, the repository follows a **manifest
 ### Key files
 
 - [data/download_data.sh](data/download_data.sh) — downloads the Foldbench subset from a target list or CSV input.
-- [data/preproces_data.py](data/preproces_data.py) — rebuilds manifests, normalizes local paths, and emits YAML summaries.
+- [data/preprocess_data.py](data/preprocess_data.py) — rebuilds manifests, normalizes local paths, and emits YAML summaries.
+- [data/loader_wrappers.py](data/loader_wrappers.py) — convenience builders for plain dataloaders and deterministic train/eval splits over one dataset.
 - [data/dataloaders.py](data/dataloaders.py) — dataset layer that maps manifests, mmCIF structures, MSA files, and torsion targets into tensors.
 - [scripts/prepare_data.py](scripts/prepare_data.py) — high-level CLI for downloading data, refreshing manifests, and smoke-testing loaders.
 - [model/alphafold2.py](model/alphafold2.py) — top-level AlphaFold2-like model that wires embeddings, Evoformer, structure, recycling, and heads.
@@ -125,11 +129,12 @@ To make experimentation easier to reproduce, the repository follows a **manifest
 - [model/alphafold2_full_loss.py](model/alphafold2_full_loss.py) — full training loss orchestrator combining FAPE, distogram, pLDDT, and torsion supervision.
 - [model/losses/](model/losses/) — component losses and helpers for geometry-aware supervision.
 - [training/train_one_epoch.py](training/train_one_epoch.py) — per-epoch optimization routine with AMP, recycling, logging, and metric collection.
+- [training/eval_one_epoch.py](training/eval_one_epoch.py) — evaluation loop that mirrors training-time logging without optimizer steps.
 - [training/train_alphafold2.py](training/train_alphafold2.py) — full training orchestrator for checkpointing, resume, monitoring, and epoch scheduling.
 - [training/ablations/catalog.py](training/ablations/catalog.py) — registry of prebuilt architecture and loss ablations resolved on top of a base experiment config.
 - [training/ablations/runtime.py](training/ablations/runtime.py) — resolves baseline or named ablations into a safe config variant without changing the default training path.
-- [training/train_paralel/data_parallel.py](training/train_paralel/data_parallel.py) — DDP utilities, distributed samplers, and rank synchronization helpers.
-- [training/train_paralel/model_parallel.py](training/train_paralel/model_parallel.py) — two-stage model-parallel wrapper for splitting AlphaFold2 across GPUs.
+- [training/train_parallel/data_parallel.py](training/train_parallel/data_parallel.py) — DDP utilities, distributed samplers, and rank synchronization helpers.
+- [training/train_parallel/model_parallel.py](training/train_parallel/model_parallel.py) — two-stage model-parallel wrapper for splitting AlphaFold2 across GPUs.
 - [scripts/train_model.py](scripts/train_model.py) — standard config-driven single-device training launcher.
 - [scripts/train_parallel.py](scripts/train_parallel.py) — multi-GPU launcher for DDP, model parallelism, and hybrid setups.
 - [scripts/train_ablation.py](scripts/train_ablation.py) — single-device launcher for named architecture and loss ablations.
@@ -150,6 +155,9 @@ The repository includes a tiny downloaded test subset under [data/af_subset_show
 python3 -m venv .venv
 source .venv/bin/activate
 pip install -r requirements.txt
+
+# Editable install with package metadata and CLI entry points
+pip install -e '.[dev,data]'
 ```
 
 ### 2) Download the subset
@@ -167,7 +175,7 @@ python3 scripts/prepare_data.py download --targets-csv data/Proteinas_secuencias
 ### 3) Rebuild the manifest with local paths
 
 ```bash
-python3 -m data.preproces_data \
+python3 -m data.preprocess_data \
   --config config/data/foldbench_subset.yaml \
   --json-path data/af_subset/jsons/fb_protein.json \
   --msa-root data/af_subset/foldbench_msas \
@@ -203,7 +211,7 @@ dataset = FoldbenchProteinDataset(manifest_csv="data/Proteinas_secuencias.csv")
 
 ### Minimal Python setup
 
-The full notebook [notebooks/train_model_local.ipynb](notebooks/train_model_local.ipynb) exposes many knobs, but the smallest useful training setup looks like this:
+The full notebook [notebooks/train_model_setup_examples.ipynb](notebooks/train_model_setup_examples.ipynb) exposes many knobs, but the smallest useful training setup looks like this:
 
 ```python
 import torch
@@ -443,7 +451,7 @@ Low-VRAM preset for Colab-class GPUs in the `15-20 GB` range, using a reduced tr
 
 This file is a **reference document**, not a statement that the current code already consumes every field end-to-end.
 
-Its role is to provide a structured target for future extension and to document the broader AlphaFold/OpenFold design space.
+Its role is to provide a structured target for future extension and to document the broader AlphaFold/OpenFold design space. It also includes a `current_repo_alignment` section that maps the nested reference schema to the flat config fields consumed by the current codebase.
 
 ---
 
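The README's key-files list describes `training/train_alphafold2.py` as an orchestrator for checkpointing and resume. A framework-agnostic sketch of one core checkpointing technique, atomic writes so a crash mid-save never corrupts the latest checkpoint, is shown below; the repository presumably serializes real model/optimizer state (e.g. via `torch.save`), and all names here are illustrative:

```python
import os
import pickle
import tempfile

def save_checkpoint(path: str, step: int, model_state: dict, optim_state: dict) -> None:
    """Atomically write a checkpoint: dump to a temp file, then rename.

    os.replace is atomic on POSIX, so readers always see either the old
    checkpoint or the complete new one, never a half-written file.
    """
    payload = {"step": step, "model": model_state, "optim": optim_state}
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(payload, f)
    os.replace(tmp, path)

def load_checkpoint(path: str) -> dict:
    """Restore the saved step and state dicts for resume."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

Resume then amounts to calling `load_checkpoint` before the epoch loop and restoring the step counter along with model and optimizer state.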
config/data/foldbench_subset.yaml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ metadata:
   provenance:
     source_manifest_csv: data/Proteinas_secuencias.csv
     source_note: The checked-in CSV was generated in Colab and stores /content paths.
-    refresh_note: Re-run data.preproces_data locally to rewrite paths for your machine.
+    refresh_note: Re-run data.preprocess_data locally to rewrite paths for your machine.
 
 paths:
   dataset_root: data/af_subset

config/experiments/af2_low_vram.yaml

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,9 @@ data:
   loader:
     batch_size: 1
     shuffle: true
+    eval_size: 1
+    eval_shuffle: false
+    split_seed: 42
     num_workers: 0
     pin_memory: false
@@ -112,6 +115,7 @@ trainer:
   run_name: af2_low_vram
   save_every: 1
   save_last: true
+  eval_every: 1
 
 geometry:
   ideal_backbone_local:
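The new loader keys above (`eval_size`, `eval_shuffle`, `split_seed`) imply a seeded, reproducible train/eval split over one dataset. A minimal sketch of that idea, making no assumptions about the actual `loader_wrappers` API (names here are illustrative):

```python
import random

def split_indices(n_items: int, eval_size: int, split_seed: int = 42):
    """Deterministically partition dataset indices into (train, eval).

    Seeding a private Random instance makes the split reproducible across
    runs and independent of the global random state, so a resumed run
    evaluates on exactly the same held-out samples.
    """
    indices = list(range(n_items))
    random.Random(split_seed).shuffle(indices)
    return indices[eval_size:], indices[:eval_size]

# Mirroring the config above: one held-out sample, fixed seed 42.
train_idx, eval_idx = split_indices(n_items=8, eval_size=1, split_seed=42)
```

The two index lists would then back separate dataloaders, with `eval_shuffle: false` keeping evaluation order stable between epochs.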

config/experiments/alphafold2_full_reference.yaml

Lines changed: 50 additions & 0 deletions
@@ -10,6 +10,56 @@ metadata:
   - Includes template, extra MSA and auxiliary loss settings that the current repo does not fully consume yet.
   - Training defaults below reflect OpenFold reference settings when explicitly defined.
 
+current_repo_alignment:
+  purpose: Map the nested AlphaFold/OpenFold reference schema to the flat runnable configs under config/experiments/af2_*.yaml.
+  consumed_directly_by_code: false
+  flat_config_equivalents:
+    data.max_msa_seqs: data.initial_training.max_msa_clusters
+    data.max_extra_msa_seqs: data.common.max_extra_msa
+    data.max_templates: globals.max_templates
+    data.crop_size: data.initial_training.crop_size
+    model.c_m: globals.c_m
+    model.c_z: globals.c_z
+    model.c_s: globals.c_s
+    model.max_relpos: globals.max_relative_feature
+    model.num_evoformer_blocks: model.evoformer.no_blocks
+    model.num_structure_blocks: model.structure_module.no_blocks
+    model.recycle_min_bin: model.recycling_embedder.min_bin
+    model.recycle_max_bin: model.recycling_embedder.max_bin
+    model.recycle_dist_bins: model.recycling_embedder.num_bins
+    model.extra_msa_stack_enabled: model.extra_msa.enabled
+    model.extra_msa_dim: model.extra_msa.c_in
+    model.extra_msa_c_e: model.extra_msa.c_out
+    model.extra_msa_num_blocks: model.extra_msa.no_blocks
+    model.template_stack_enabled: model.template.enabled
+    model.template_c_t: globals.c_t
+    model.template_num_blocks: model.template.pair_stack.no_blocks
+    model.dist_bins: heads.distogram.num_bins
+    model.plddt_bins: heads.plddt.num_bins
+    loss.dist_num_bins: heads.distogram.num_bins
+    loss.dist_min_bin: heads.distogram.min_bin
+    loss.dist_max_bin: heads.distogram.max_bin
+    loss.plddt_num_bins: heads.plddt.num_bins
+    loss.plddt_inclusion_radius: heads.plddt.cutoff
+  current_support:
+    implemented:
+      - Evoformer trunk
+      - extra MSA stack
+      - template conditioning
+      - recycling embedder
+      - IPA-based structure module
+      - distogram, pLDDT, and torsion heads/losses
+    partial:
+      - input feature pipeline
+      - template retrieval pipeline
+      - structure-module hyperparameter surface
+    not_yet_implemented:
+      - masked MSA objective
+      - experimentally resolved head
+      - violation loss
+      - TM head
+      - all-atom and side-chain reconstruction
+
 globals:
   c_m: 256
   c_z: 128
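The `flat_config_equivalents` table above maps flat config keys to dotted paths into the nested reference schema. A generic way to consume such a table is a dotted-path lookup over the parsed YAML dict; the helper below is a hypothetical sketch (the repo is not stated to ship one), shown on a tiny hand-built fragment of the reference config:

```python
def get_dotted(cfg: dict, dotted: str, default=None):
    """Resolve a dotted path like 'globals.c_m' against a nested dict."""
    node = cfg
    for key in dotted.split("."):
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node

def flatten_with_alignment(nested_cfg: dict, alignment: dict) -> dict:
    """Build a flat config from the nested reference, using a mapping of
    flat key -> dotted path (the shape of flat_config_equivalents above)."""
    return {flat: get_dotted(nested_cfg, path) for flat, path in alignment.items()}

# Hand-built fragment of the nested reference, mirroring the globals above.
reference = {"globals": {"c_m": 256, "c_z": 128}}
mapping = {"model.c_m": "globals.c_m", "model.c_z": "globals.c_z"}
flat = flatten_with_alignment(reference, mapping)
```

Missing paths fall back to `default`, which matches the file's stated role as a reference that the code does not yet fully consume.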
