
Commit 29b4128

Add TM head and logging for predicted TM score in AlphaFold2 model
- Introduced TM head with configurable parameters in model configurations.
- Updated AlphaFold2 class to include TM logits and predicted TM score.
- Enhanced training and evaluation scripts to log predicted TM metrics.
- Modified YAML configuration files to support new TM head settings.
- Added tests to validate TM head functionality and logging behavior.
1 parent 1e44e08 commit 29b4128

13 files changed: +352, -69 lines


config/experiments/af2_canonical.yaml

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ model:
   masked_msa_num_classes: 23
   masked_msa_head_enabled: true
   plddt_head_enabled: true
+  tm_num_bins: 64
+  tm_max_error: 31.5
+  tm_head_enabled: false
   torsion_head_enabled: true
 
 loss:

config/experiments/af2_low_vram.yaml

Lines changed: 3 additions & 0 deletions
@@ -70,6 +70,9 @@ model:
   masked_msa_num_classes: 23
   masked_msa_head_enabled: true
   plddt_head_enabled: true
+  tm_num_bins: 64
+  tm_max_error: 31.5
+  tm_head_enabled: false
   torsion_head_enabled: true
 
 loss:

config/experiments/af2_poc.yaml

Lines changed: 3 additions & 0 deletions
@@ -64,6 +64,9 @@ model:
   masked_msa_num_classes: 23
   masked_msa_head_enabled: true
   plddt_head_enabled: true
+  tm_num_bins: 64
+  tm_max_error: 31.5
+  tm_head_enabled: false
   torsion_head_enabled: true
 
 loss:
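
All three experiment configs gain the same three keys, with `tm_head_enabled` left off by default. A minimal sketch of how these settings would reach the model, assuming the `model:` block maps one-to-one onto `AlphaFold2` constructor keyword arguments; the config-loading plumbing itself is not part of this commit:

```python
import yaml

from model.alphafold2 import AlphaFold2

# Hypothetical loading step: this commit does not show how experiment YAMLs
# are parsed, only that the new keys mirror AlphaFold2's constructor arguments.
with open("config/experiments/af2_poc.yaml") as fh:
    cfg = yaml.safe_load(fh)

# tm_num_bins=64, tm_max_error=31.5, tm_head_enabled=False in all three configs.
model = AlphaFold2(**cfg["model"])
```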

model/alphafold2.py

Lines changed: 16 additions & 0 deletions
@@ -35,6 +35,7 @@ class AlphaFold2(nn.Module):
     - torsion angles
     - pLDDT
     - distogram logits
+    - optional TM logits and pTM
     """
     @staticmethod
     def _normalize_ablation_id(ablation):
@@ -72,6 +73,7 @@ def resolve_ablation_defaults(cls, ablation):
                 "distogram_head_enabled": False,
                 "masked_msa_head_enabled": False,
                 "plddt_head_enabled": False,
+                "tm_head_enabled": False,
                 "torsion_head_enabled": False,
             },
             4: {
@@ -105,6 +107,8 @@ def __init__(
         dist_bins=64,
         masked_msa_num_classes=23,
         plddt_bins=50,
+        tm_num_bins=64,
+        tm_max_error=31.5,
         n_torsions=7,
         num_res_blocks_torsion=2,
         recycle_min_bin=3.25,
@@ -132,6 +136,7 @@ def __init__(
         distogram_head_enabled=True,
         masked_msa_head_enabled=True,
         plddt_head_enabled=True,
+        tm_head_enabled=False,
         torsion_head_enabled=True):
 
         super().__init__()
@@ -164,6 +169,7 @@ def __init__(
         distogram_head_enabled = ablation_defaults.get("distogram_head_enabled", distogram_head_enabled)
         masked_msa_head_enabled = ablation_defaults.get("masked_msa_head_enabled", masked_msa_head_enabled)
         plddt_head_enabled = ablation_defaults.get("plddt_head_enabled", plddt_head_enabled)
+        tm_head_enabled = ablation_defaults.get("tm_head_enabled", tm_head_enabled)
         torsion_head_enabled = ablation_defaults.get("torsion_head_enabled", torsion_head_enabled)
 
         self.ablation = self._normalize_ablation_id(ablation)
@@ -185,6 +191,7 @@ def __init__(
         self.distogram_head_enabled = bool(distogram_head_enabled)
         self.masked_msa_head_enabled = bool(masked_msa_head_enabled)
         self.plddt_head_enabled = bool(plddt_head_enabled)
+        self.tm_head_enabled = bool(tm_head_enabled)
         self.torsion_head_enabled = bool(torsion_head_enabled)
 
 
@@ -222,6 +229,7 @@ def __init__(
         self.plddt_head = PlddtHead(c_s=c_s, num_bins=plddt_bins)
         self.distogram_head = DistogramHead(c_z=c_z, num_bins=dist_bins)
         self.masked_msa_head = MaskedMsaHead(c_m=c_m, num_classes=masked_msa_num_classes)
+        self.tm_head = TMHead(c_z=c_z, num_bins=tm_num_bins, max_error=tm_max_error)
         self.torsion_head = TorsionHead(c_s=c_s, n_torsions=n_torsions , num_res_blocks = num_res_blocks_torsion)
         self.recycling_embedder = RecyclingEmbedder(
             c_m=c_m,
@@ -251,6 +259,7 @@ def __init__(
         zero_init_linear(self.plddt_head.mlp[-1])
         zero_init_linear(self.distogram_head.linear)
         zero_init_linear(self.masked_msa_head.linear)
+        zero_init_linear(self.tm_head.linear)
 
         self._freeze_module(self.evoformer, enabled=self.evoformer_enabled)
         self._freeze_module(self.extra_msa_stack, enabled=self.extra_msa_stack_enabled)
@@ -261,6 +270,7 @@ def __init__(
         self._freeze_module(self.distogram_head, enabled=self.distogram_head_enabled)
         self._freeze_module(self.masked_msa_head, enabled=self.masked_msa_head_enabled)
         self._freeze_module(self.plddt_head, enabled=self.plddt_head_enabled)
+        self._freeze_module(self.tm_head, enabled=self.tm_head_enabled)
         self._freeze_module(self.torsion_head, enabled=self.torsion_head_enabled)
 
     @staticmethod
@@ -420,6 +430,10 @@ def forward(
         masked_msa_logits = None
         if self.masked_msa_head_enabled:
             masked_msa_logits = self.masked_msa_head(m[:, :original_msa_depth])
+        if self.tm_head_enabled:
+            tm_logits, ptm = self.tm_head(z, residue_mask=seq_mask)
+        else:
+            tm_logits, ptm = None, None
 
         # single repr + structure
         s0 = self.single_proj(m)
@@ -484,6 +498,8 @@ def forward(
             "plddt": plddt,
             "distogram_logits": distogram_logits,
             "masked_msa_logits": masked_msa_logits,
+            "tm_logits": tm_logits,
+            "ptm": ptm,
         }
 
         if recycle_idx < num_recycles:
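
Note that the head is constructed, zero-initialized, and registered unconditionally; only its use in `forward` and its trainability are gated. A minimal sketch of the construction-time behavior, reusing the constructor arguments from the test further below, and assuming `_freeze_module` disables gradients for modules whose `enabled` flag is false, as its use in the ablation defaults suggests:

```python
import torch

from model.alphafold2 import AlphaFold2

model = AlphaFold2(
    n_tokens=27,
    c_m=256,
    c_z=128,
    c_s=256,
    max_relpos=32,
    pad_idx=0,
    num_evoformer_blocks=1,
    num_structure_blocks=1,
    transition_expansion_evoformer=2,
    transition_expansion_structure=2,
    use_block_specific_params=False,
    dist_bins=64,
    plddt_bins=50,
    tm_num_bins=64,
    tm_head_enabled=False,  # the default shipped in all three YAML configs
    n_torsions=3,
    num_res_blocks_torsion=1,
)

# zero_init_linear runs regardless of the enabled flag, so the projection
# starts at zero; a disabled head should also carry no trainable parameters.
assert torch.all(model.tm_head.linear.weight == 0)
assert all(not p.requires_grad for p in model.tm_head.parameters())
```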

model/alphafold2_heads.py

Lines changed: 91 additions & 2 deletions
@@ -1,8 +1,8 @@
 """Prediction heads built on top of the shared AlphaFold representations.
 
 The classes in this module project internal sequence or pair features into the
-single representation, pLDDT logits, and distogram logits used by the model
-output dictionary and downstream loss computation.
+single representation, pLDDT logits, distogram logits, masked-MSA logits, and
+an optional predicted-TM head used for confidence-style reporting.
 """
 
 import torch
@@ -75,3 +75,92 @@ def __init__(self, c_m=256, num_classes=23):
     def forward(self, m):
         logits = self.linear(self.ln(m))
         return logits
+
+
+def compute_predicted_tm_score(
+    tm_logits: torch.Tensor,
+    *,
+    residue_mask: torch.Tensor | None = None,
+    bin_centers: torch.Tensor | None = None,
+    eps: float = 1e-8,
+) -> torch.Tensor:
+    """AlphaFold pTM lower bound from pairwise error logits.
+
+    Parameters
+    ----------
+    tm_logits : [B, L, L, num_bins]
+        Logits over aligned-error bins derived from the final pair representation.
+    residue_mask : [B, L], optional
+        Valid residues to include in the domain / chain subset.
+    bin_centers : [num_bins], optional
+        Representative error values for each aligned-error bin.
+    eps : float
+        Small numerical constant.
+    """
+
+    if tm_logits.ndim != 4:
+        raise ValueError(f"tm_logits must have shape [B, L, L, C], got {tuple(tm_logits.shape)}")
+
+    batch_size, length, _, num_bins = tm_logits.shape
+    if bin_centers is None:
+        if num_bins <= 1:
+            bin_width = 0.5
+        else:
+            bin_width = 31.5 / float(num_bins - 1)
+        bin_centers = torch.arange(num_bins, device=tm_logits.device, dtype=tm_logits.dtype)
+        bin_centers = bin_width * (bin_centers + 0.5)
+    else:
+        bin_centers = bin_centers.to(device=tm_logits.device, dtype=tm_logits.dtype)
+        if bin_centers.numel() != num_bins:
+            raise ValueError(
+                f"bin_centers must have {num_bins} entries, got {bin_centers.numel()}"
+            )
+
+    if residue_mask is None:
+        residue_mask = torch.ones(batch_size, length, device=tm_logits.device, dtype=tm_logits.dtype)
+    else:
+        residue_mask = residue_mask.to(device=tm_logits.device, dtype=tm_logits.dtype)
+
+    num_res = residue_mask.sum(dim=-1).clamp_min(1.0)
+    d0 = 1.24 * torch.clamp(num_res, min=19.0).sub(15.0).pow(1.0 / 3.0) - 1.8
+    d0 = d0.clamp_min(0.5)
+
+    probs = F.softmax(tm_logits, dim=-1)
+    tm_kernel = 1.0 / (1.0 + (bin_centers.view(1, 1, 1, -1) / (d0.view(-1, 1, 1, 1) + eps)) ** 2)
+    expected_tm = (probs * tm_kernel).sum(dim=-1)
+
+    per_alignment = (expected_tm * residue_mask[:, None, :]).sum(dim=-1) / num_res.view(-1, 1)
+    per_alignment = per_alignment.masked_fill(residue_mask <= 0, float("-inf"))
+
+    ptm = per_alignment.max(dim=-1).values
+    has_valid = residue_mask.sum(dim=-1) > 0
+    ptm = torch.where(has_valid, ptm, torch.zeros_like(ptm))
+    return ptm
+
+
+class TMHead(nn.Module):
+    def __init__(self, c_z=128, num_bins=64, max_error=31.5):
+        super().__init__()
+        self.num_bins = int(num_bins)
+        self.max_error = float(max_error)
+        self.ln = nn.LayerNorm(c_z)
+        self.linear = nn.Linear(c_z, self.num_bins)
+
+        if self.num_bins <= 1:
+            bin_width = 0.5
+        else:
+            bin_width = self.max_error / float(self.num_bins - 1)
+        bin_centers = bin_width * (torch.arange(self.num_bins, dtype=torch.float32) + 0.5)
+        self.register_buffer("bin_centers", bin_centers, persistent=False)
+
+    def compute_ptm(self, tm_logits, residue_mask=None):
+        return compute_predicted_tm_score(
+            tm_logits,
+            residue_mask=residue_mask,
+            bin_centers=self.bin_centers,
+        )
+
+    def forward(self, z, residue_mask=None):
+        logits = self.linear(self.ln(z))
+        ptm = self.compute_ptm(logits, residue_mask=residue_mask)
+        return logits, ptm
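
One way to sanity-check the new head in isolation: all-zero logits make the softmax uniform over bins, so every pair's expected TM term collapses to the mean of the kernel 1 / (1 + (c_b / d0)^2) over the bin centers c_b, with d0 = max(1.24 * (max(N, 19) - 15)^(1/3) - 1.8, 0.5) as in the code above, and pTM equals that mean. A minimal sketch against the committed module:

```python
import torch

from model.alphafold2_heads import TMHead, compute_predicted_tm_score

head = TMHead(c_z=128, num_bins=64, max_error=31.5)

# Uniform case: zero logits give a uniform softmax, so pTM is the mean kernel.
logits = torch.zeros(1, 50, 50, 64)
ptm = compute_predicted_tm_score(logits, bin_centers=head.bin_centers)

n = 50.0
d0 = max(1.24 * (max(n, 19.0) - 15.0) ** (1.0 / 3.0) - 1.8, 0.5)
expected = (1.0 / (1.0 + (head.bin_centers / d0) ** 2)).mean()
assert torch.allclose(ptm, expected, atol=1e-5)

# End to end: random pair features through the head itself.
z = torch.randn(1, 50, 50, 128)
tm_logits, ptm = head(z)
assert tm_logits.shape == (1, 50, 50, 64)
assert 0.0 <= float(ptm) <= 1.0
```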

scripts/ablations/run_suite.py

Lines changed: 2 additions & 0 deletions
@@ -148,6 +148,7 @@ def _write_comparison_tables(rows: list[dict], output_dir: Path) -> None:
         "msa_loss",
         "plddt_loss",
         "torsion_loss",
+        "ptm_logged",
         "rmsd_logged",
         "tm_score_logged",
         "gdt_ts_logged",
@@ -200,6 +201,7 @@ def main(argv: Sequence[str] | None = None) -> None:
         "msa_loss": stats.get("msa_loss"),
         "plddt_loss": stats.get("plddt_loss"),
         "torsion_loss": stats.get("torsion_loss"),
+        "ptm_logged": stats.get("ptm_logged"),
         "rmsd_logged": stats.get("rmsd_logged"),
         "tm_score_logged": stats.get("tm_score_logged"),
         "gdt_ts_logged": stats.get("gdt_ts_logged"),

tests/test_ablation_suite.py

Lines changed: 1 addition & 0 deletions
@@ -187,6 +187,7 @@ def test_alphafold2_ablation_defaults_are_explicit_and_baseline_safe():
     assert AlphaFold2.resolve_ablation_defaults(2)["recycle_single_enabled"] is False
     assert AlphaFold2.resolve_ablation_defaults(3)["masked_msa_head_enabled"] is False
     assert AlphaFold2.resolve_ablation_defaults(3)["plddt_head_enabled"] is False
+    assert AlphaFold2.resolve_ablation_defaults(3)["tm_head_enabled"] is False
     assert AlphaFold2.resolve_ablation_defaults(4)["use_block_specific_params"] is True
     assert AlphaFold2.resolve_ablation_defaults(5)["recycle_single_enabled"] is False
     assert AlphaFold2.resolve_ablation_defaults(5)["evoformer_enabled"] is False

tests/test_forward_model.py

Lines changed: 43 additions & 0 deletions
@@ -4,6 +4,8 @@
 
 import torch
 
+from model.alphafold2 import AlphaFold2
+
 
 def test_alphafold2_forward_smoke(toy_model, toy_batch):
     with torch.no_grad():
@@ -28,6 +30,8 @@ def test_alphafold2_forward_smoke(toy_model, toy_batch):
     assert outputs["plddt"].shape == (batch_size, length)
     assert outputs["distogram_logits"].shape == (batch_size, length, length, 64)
     assert outputs["masked_msa_logits"].shape == (batch_size, toy_batch["msa_tokens"].shape[1], length, 23)
+    assert outputs["tm_logits"] is None
+    assert outputs["ptm"] is None
 
     for value in outputs.values():
         if torch.is_tensor(value):
@@ -41,6 +45,45 @@ def test_alphafold2_forward_smoke(toy_model, toy_batch):
     assert torch.all((outputs["plddt"] >= 0.0) & (outputs["plddt"] <= 100.0))
 
 
+def test_alphafold2_tm_head_can_be_enabled(toy_batch):
+    torch.manual_seed(11)
+    model = AlphaFold2(
+        n_tokens=27,
+        c_m=256,
+        c_z=128,
+        c_s=256,
+        max_relpos=32,
+        pad_idx=0,
+        num_evoformer_blocks=1,
+        num_structure_blocks=1,
+        transition_expansion_evoformer=2,
+        transition_expansion_structure=2,
+        use_block_specific_params=False,
+        dist_bins=64,
+        plddt_bins=50,
+        tm_num_bins=64,
+        tm_head_enabled=True,
+        n_torsions=3,
+        num_res_blocks_torsion=1,
+    ).eval()
+
+    with torch.no_grad():
+        outputs = model(
+            seq_tokens=toy_batch["seq_tokens"],
+            msa_tokens=toy_batch["msa_tokens"],
+            seq_mask=toy_batch["seq_mask"],
+            msa_mask=toy_batch["msa_mask"],
+            ideal_backbone_local=toy_batch["ideal_backbone_local"],
+        )
+
+    batch_size, length = toy_batch["seq_tokens"].shape
+    assert outputs["tm_logits"].shape == (batch_size, length, length, 64)
+    assert outputs["ptm"].shape == (batch_size,)
+    assert torch.isfinite(outputs["tm_logits"]).all()
+    assert torch.isfinite(outputs["ptm"]).all()
+    assert torch.all((outputs["ptm"] >= 0.0) & (outputs["ptm"] <= 1.0))
+
+
 def test_alphafold_loss_orchestrator_returns_finite_components(toy_model, toy_batch, toy_criterion):
     with torch.no_grad():
         outputs = toy_model(
