|
| 1 | +"""Vision encoder front end: patch embedding plus 2D sinusoidal position. |
| 2 | +
|
| 3 | +Tokenizes a 224x224x3 image into a sequence of 196 patch tokens plus a CLS |
| 4 | +token. The patch projection is a Conv2d with kernel and stride equal to the |
| 5 | +patch size, which is numerically identical to flatten-then-linear. The |
| 6 | +position signal is a fixed 2D sinusoidal table; half the embedding dim encodes |
| 7 | +row position, the other half encodes column position, at multiple frequencies. |
| 8 | +
|
| 9 | +Run with: python3 main.py |
| 10 | +""" |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import math |
| 15 | +from dataclasses import dataclass |
| 16 | + |
| 17 | +import numpy as np |
| 18 | +import torch |
| 19 | +import torch.nn as nn |
| 20 | + |
| 21 | + |
@dataclass(frozen=True)
class FrontEndConfig:
    """Immutable hyper-parameters shared by the patch-embedding front end.

    Validation is lazy: the divisibility requirement is only enforced when
    ``grid_size`` (or ``num_patches``, which builds on it) is read.
    """

    # Expected height and width of the input image, in pixels.
    image_size: int = 224
    # Kernel size and stride of the patch projection, in pixels.
    patch_size: int = 16
    # Number of channels the projection expects on its input.
    in_channels: int = 3
    # Width of every output token.
    hidden: int = 768

    @property
    def grid_size(self) -> int:
        """Return how many patches fit along one side of the image.

        Raises:
            ValueError: if ``patch_size`` does not evenly divide ``image_size``.
        """
        per_side, leftover = divmod(self.image_size, self.patch_size)
        if leftover != 0:
            raise ValueError(
                f"patch_size {self.patch_size} must divide image_size {self.image_size}"
            )
        return per_side

    @property
    def num_patches(self) -> int:
        """Return the total number of patches, i.e. ``grid_size`` squared."""
        per_side = self.grid_size
        return per_side * per_side
| 40 | + |
| 41 | + |
def sinusoidal_2d(
    grid_h: int, grid_w: int, dim: int, temperature: float = 10000.0
) -> torch.Tensor:
    """Build a deterministic 2D sinusoidal position table.

    The first half of the channels encodes the row index and the second half
    encodes the column index. Each half is laid out as ``[sin..., cos...]``
    over ``dim // 4`` geometrically spaced frequencies. Rows of the returned
    table follow row-major grid order, so entry ``r * grid_w + c`` belongs to
    grid cell ``(r, c)``. There is no learned state: identical arguments
    always produce identical tables.

    Args:
        grid_h: Number of grid rows.
        grid_w: Number of grid columns.
        dim: Embedding width; must be divisible by 4.
        temperature: Base of the frequency band. The default of 10000.0
            reproduces the standard Transformer sin/cos band.

    Returns:
        Tensor of shape ``(grid_h * grid_w, dim)``.

    Raises:
        ValueError: if ``dim`` is not divisible by 4 or ``temperature`` is
            not positive.
    """
    if dim % 4 != 0:
        raise ValueError(f"sinusoidal_2d dim must be divisible by 4, got {dim}")
    if temperature <= 0:
        raise ValueError(
            f"sinusoidal_2d temperature must be positive, got {temperature}"
        )
    half = dim // 2
    quarter = half // 2

    # inv[k] = temperature ** (-k / quarter); max() guards the dim == 0 case.
    freq = torch.arange(quarter, dtype=torch.float32)
    inv = torch.exp(-math.log(temperature) * freq / max(1, quarter))

    # Outer products: (grid_h, quarter) and (grid_w, quarter) phase angles.
    rows = torch.arange(grid_h, dtype=torch.float32).unsqueeze(1) * inv.unsqueeze(0)
    cols = torch.arange(grid_w, dtype=torch.float32).unsqueeze(1) * inv.unsqueeze(0)

    row_emb = torch.cat([torch.sin(rows), torch.cos(rows)], dim=1)
    col_emb = torch.cat([torch.sin(cols), torch.cos(cols)], dim=1)

    # Broadcast the row code across columns and the column code across rows.
    table = torch.zeros(grid_h, grid_w, dim)
    table[:, :, :half] = row_emb.unsqueeze(1).expand(-1, grid_w, -1)
    table[:, :, half:] = col_emb.unsqueeze(0).expand(grid_h, -1, -1)
    return table.reshape(grid_h * grid_w, dim)
| 67 | + |
| 68 | + |
class PatchEmbed(nn.Module):
    """Patch projection as a strided Conv2d.

    The convolution uses kernel size and stride equal to ``cfg.patch_size``,
    so every output location summarizes exactly one non-overlapping patch.

    On a ``(B, cfg.in_channels, cfg.image_size, cfg.image_size)`` input the
    output shape is ``(B, N, cfg.hidden)`` with
    ``N = (cfg.image_size // cfg.patch_size) ** 2``.
    """

    def __init__(self, cfg: FrontEndConfig) -> None:
        """Create the projection for ``cfg``.

        Raises:
            ValueError: if ``cfg.patch_size`` does not divide
                ``cfg.image_size``. Without this check the strided
                convolution would silently drop the border pixels that do
                not fill a whole patch.
        """
        super().__init__()
        if cfg.image_size % cfg.patch_size != 0:
            raise ValueError(
                f"patch_size {cfg.patch_size} must divide image_size {cfg.image_size}"
            )
        self.cfg = cfg
        self.proj = nn.Conv2d(
            cfg.in_channels,
            cfg.hidden,
            kernel_size=cfg.patch_size,
            stride=cfg.patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project an image batch to a sequence of patch tokens.

        Args:
            x: Tensor of shape (B, C, H, W) matching the configured channel
                count and image size.

        Returns:
            Tensor of shape (B, N, hidden), patches in row-major order.

        Raises:
            ValueError: if ``x`` is not 4D, or its channel or spatial
                dimensions differ from the configuration.
        """
        if x.dim() != 4:
            raise ValueError(f"expected 4D input (B,C,H,W), got shape {tuple(x.shape)}")
        if x.shape[1] != self.cfg.in_channels:
            raise ValueError(
                f"channel mismatch: got {x.shape[1]}, expected {self.cfg.in_channels}"
            )
        if x.shape[2] != self.cfg.image_size or x.shape[3] != self.cfg.image_size:
            raise ValueError(
                f"spatial mismatch: got {tuple(x.shape[2:])}, expected "
                f"({self.cfg.image_size}, {self.cfg.image_size})"
            )
        # (B, hidden, gh, gw) -> (B, hidden, gh * gw) -> (B, gh * gw, hidden)
        return self.proj(x).flatten(2).transpose(1, 2)
| 103 | + |
| 104 | + |
class VisionFrontEnd(nn.Module):
    """Turn an image batch into position-aware tokens led by a CLS token.

    Pipeline: patch projection, then a learnable CLS token prepended to the
    sequence, then a fixed 2D sinusoidal position table added to every token.

    Output shape: (B, num_patches + 1, hidden).
    """

    def __init__(self, cfg: FrontEndConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.patch = PatchEmbed(cfg)

        # One learnable CLS vector, shared by every sample in a batch.
        cls_init = torch.zeros(1, 1, cfg.hidden)
        self.cls_token = nn.Parameter(cls_init)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        side = cfg.grid_size
        patch_pos = sinusoidal_2d(side, side, cfg.hidden)
        # The CLS slot carries an all-zero position row ahead of the patch rows.
        table = torch.cat([torch.zeros(1, cfg.hidden), patch_pos], dim=0)
        # Non-persistent: the table is rebuilt from cfg, so it is kept out of
        # the state dict.
        self.register_buffer("pos_embed", table.unsqueeze(0), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return tokens of shape (B, num_patches + 1, hidden) for images ``x``."""
        patch_tokens = self.patch(x)
        batch = patch_tokens.size(0)
        # expand() broadcasts the single CLS vector across the batch without copying.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        sequence = torch.cat([cls_tokens, patch_tokens], dim=1)
        return sequence + self.pos_embed
| 130 | + |
| 131 | + |
def synthesize_image(seed: int, image_size: int = 224, channels: int = 3) -> torch.Tensor:
    """Build a deterministic (1, channels, image_size, image_size) fixture.

    Values are float32 in [0, 1]. A smooth gradient is added on top of
    Gaussian noise so the patch projection has both low and high frequency
    content to summarize.

    Channel ``c`` uses one of three gradient planes, chosen by ``c % 3``:
    horizontal ramp, vertical ramp, or their average. With the default
    ``channels=3`` this is exactly one plane per channel.

    Args:
        seed: Seed for ``numpy.random.default_rng``; equal seeds give equal images.
        image_size: Height and width of the generated image.
        channels: Number of channels to generate; must be at least 1.

    Returns:
        Float32 tensor of shape (1, channels, image_size, image_size).

    Raises:
        ValueError: if ``channels`` is smaller than 1.
    """
    if channels < 1:
        raise ValueError(f"channels must be >= 1, got {channels}")

    rng = np.random.default_rng(seed)
    noise = rng.standard_normal((channels, image_size, image_size)).astype("float32") * 0.1

    y_coords = np.linspace(0.0, 1.0, image_size, dtype="float32")
    x_coords = np.linspace(0.0, 1.0, image_size, dtype="float32")
    gx, gy = np.meshgrid(x_coords, y_coords, indexing="xy")

    # Cycle through the three base planes so the gradient always has exactly
    # `channels` planes and lines up with the noise tensor.
    planes = (gx, gy, (gx + gy) * 0.5)
    gradient = np.stack(
        [planes[c % len(planes)] for c in range(channels)], axis=0
    ).astype("float32")

    img = np.clip(gradient + noise + 0.5, 0.0, 1.0)
    return torch.from_numpy(img).unsqueeze(0)
| 147 | + |
| 148 | + |
def unfold_then_linear(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, patch_size: int
) -> torch.Tensor:
    """Project patches with an explicit unfold followed by a matrix multiply.

    Serves as a reference for checking that a strided Conv2d projection
    agrees with the flatten-then-linear formulation.

    Args:
        x: Input of shape (B, C, H, W).
        weight: Projection weight; everything after its first dimension is
            flattened into one axis.
        bias: Bias added to every projected token.
        patch_size: Window size and step along both spatial axes.

    Returns:
        Tensor of shape (B, num_patches, weight.shape[0]), patches in
        row-major order.

    Raises:
        ValueError: if ``x`` is not 4-dimensional.
    """
    if x.dim() != 4:
        raise ValueError(f"expected 4D input, got {tuple(x.shape)}")

    # Cut the height axis, then the width axis, into non-overlapping windows.
    tiles = x.unfold(2, patch_size, patch_size)
    tiles = tiles.unfold(3, patch_size, patch_size)
    batch, chans, n_rows, n_cols, tile_h, tile_w = tiles.shape

    # Move the grid axes forward so each token flattens as (C, ph, pw),
    # the same order as the flattened weight below.
    tokens = tiles.permute(0, 2, 3, 1, 4, 5)
    tokens = tokens.reshape(batch, n_rows * n_cols, chans * tile_h * tile_w)

    kernel = weight.reshape(weight.shape[0], -1)
    return torch.matmul(tokens, kernel.T) + bias
| 162 | + |
| 163 | + |
def describe_token_norms(tokens: torch.Tensor, max_show: int = 8) -> str:
    """Return the L2 norms of the first few tokens, formatted for inspection.

    Only the first batch item is summarized. Nothing is printed; the caller
    decides what to do with the string.

    Args:
        tokens: Tensor indexed as (batch, token, feature).
        max_show: Upper bound on how many leading tokens to include.

    Returns:
        Comma-separated norms with three decimals, e.g. ``"5.000, 0.000"``.
    """
    # Slice before reducing so norms are only computed for the tokens shown.
    shown = tokens.detach()[0, :max_show]
    return ", ".join(f"{v:.3f}" for v in shown.norm(dim=-1).tolist())
| 169 | + |
| 170 | + |
def main() -> None:
    """Run the front end on synthetic fixtures and print a sanity report.

    The report covers the configuration, the fixture image statistics, the
    output token shape and norms, a batch-consistency check, and a comparison
    of the Conv2d projection against the unfold-based reference. All forward
    passes run under ``torch.no_grad()`` because the report only reads values.
    """
    print("=" * 60)
    print("VISION ENCODER PATCHES")
    print("=" * 60)

    cfg = FrontEndConfig()
    print(f" image size : {cfg.image_size}")
    print(f" patch size : {cfg.patch_size}")
    print(f" grid size : {cfg.grid_size}x{cfg.grid_size}")
    print(f" num patches: {cfg.num_patches}")
    print(f" hidden : {cfg.hidden}")
    print(f" seq length : {cfg.num_patches + 1} (includes CLS)")

    # Seed torch so the randomly initialized projection is reproducible.
    torch.manual_seed(0)
    img = synthesize_image(seed=0)
    print(f"\nfixture image shape : {tuple(img.shape)}")
    print(f"fixture image dtype : {img.dtype}")
    print(f"fixture pixel range : [{img.min().item():.3f}, {img.max().item():.3f}]")

    model = VisionFrontEnd(cfg).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nfront-end params : {n_params:,}")

    with torch.no_grad():
        tokens = model(img)

    print(f"output token shape : {tuple(tokens.shape)}")
    print(f"CLS token norm : {tokens[0, 0].norm().item():.3f}")
    print(f"first 8 token norms : {describe_token_norms(tokens)}")

    print("\nposition embedding row signature:")
    pos_row = model.pos_embed[0, 1, :8].tolist()
    print(" pos[1, :8] =", ", ".join(f"{v:+.3f}" for v in pos_row))

    print("\nbatch consistency check:")
    # Four copies of the same image must yield four identical token rows.
    img_b4 = synthesize_image(seed=1).repeat(4, 1, 1, 1)
    with torch.no_grad():
        out_b4 = model(img_b4)
    print(f" batch=4 output shape: {tuple(out_b4.shape)}")
    drift = (out_b4 - out_b4[0:1]).abs().max().item()
    print(f" max drift across identical batch rows: {drift:.6f}")

    print("\nunfold reference vs Conv2d projection:")
    weight = model.patch.proj.weight.detach()
    bias = model.patch.proj.bias.detach()
    # Run under no_grad like the other forward passes; the comparison only
    # reads values, so no autograd graph is needed.
    with torch.no_grad():
        ref = unfold_then_linear(img, weight, bias, cfg.patch_size)
        conv = model.patch(img)
    diff = (ref - conv).abs().max().item()
    print(f" max abs diff : {diff:.6e}")
    if diff < 1e-4:
        print(" ok: unfold reference matches Conv2d to float tolerance")
    else:
        print(" FAIL: projection drifts from reference")

    print("\ndone.")
| 226 | + |
| 227 | + |
# Run the demo report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0 commit comments