-
Notifications
You must be signed in to change notification settings - Fork 9.8k
Expand file tree
/
Copy pathmodel.py
More file actions
107 lines (87 loc) · 4.33 KB
/
Copy pathmodel.py
File metadata and controls
107 lines (87 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""WiFi-CSI pose model + LoRA adapter for the RuView calibration service.
Architecture matches the published flagship checkpoint
[`ruvnet/wifi-densepose-mmfi-pose`](https://huggingface.co/ruvnet/wifi-densepose-mmfi-pose)
(`pose_mmfi_best.pt`): transformer encoder + temporal attention pooling + skeleton-graph head.
The calibration service freezes this base and fits a tiny per-room **LoRA adapter** (rank 8 on the
input projection + pose head ≈ 11K parameters, ~45 KB as float32) from ~100–200 labeled in-room
samples. Empirically that lifts cross-subject 64→72% and cross-environment 11→73%
(ADR-150 §3.3–3.6).
"""
import numpy as np
import torch
import torch.nn as nn
# COCO-17 skeleton edges (pairs of joint indices) for the graph-refinement head.
EDGES = [
    (0, 1), (0, 2), (1, 3), (2, 4), (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
    (5, 11), (6, 12), (11, 12), (11, 13), (13, 15), (12, 14), (14, 16),
]

# Row-normalised 17x17 adjacency: self-loops on the diagonal, a symmetric entry for
# every edge, then each row divided by its sum so that every row adds up to 1.
_A = np.eye(17, dtype=np.float32)
_src, _dst = (list(_ends) for _ends in zip(*EDGES))
_A[_src, _dst] = 1.0
_A[_dst, _src] = 1.0
_A /= _A.sum(axis=1, keepdims=True)
class LoRA(nn.Module):
    """Low-rank adapter wrapping a frozen Linear: y = W·x + (x·A·B)·(alpha/r).

    The wrapped layer's parameters are frozen (``requires_grad = False``); only the
    two low-rank factors ``A`` (``in_features x r``) and ``B`` (``r x out_features``)
    remain trainable. ``B`` starts at zero, so a freshly constructed adapter
    reproduces the wrapped layer's output exactly.

    Args:
        base: The linear layer to wrap. It is frozen in place, not copied.
        r: Rank of the low-rank update.
        alpha: Scaling numerator; the update is multiplied by ``alpha / r``.
    """

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        # Allocate the factors on the same device and in the same dtype as the wrapped
        # weight; otherwise adapting a layer that was already moved (e.g. to GPU or to
        # float64/half) fails in forward() with a device/dtype mismatch.
        like = {"device": base.weight.device, "dtype": base.weight.dtype}
        self.A = nn.Parameter(torch.zeros(base.in_features, r, **like))
        self.B = nn.Parameter(torch.zeros(r, base.out_features, **like))
        # Only A is randomised; B stays zero so the initial update A·B is zero.
        nn.init.normal_(self.A, std=0.02)
        self.scale = alpha / r

    def forward(self, x):
        """Return the frozen layer's output plus the scaled low-rank update."""
        return self.base(x) + (x @ self.A @ self.B) * self.scale
class GR(nn.Module):
    """Skeleton-graph refinement head for the 17-joint pose.

    Each joint gets a feature vector built from the pooled embedding, a learned
    per-joint embedding and its initial 2-D estimate. Two rounds of mixing with the
    normalised skeleton adjacency follow, and the result is a bounded offset
    (at most 0.3 per coordinate, via ``tanh``) added to the initial keypoints.

    Args:
        d: Width of the pooled embedding ``z``.
        h: Hidden width of the per-joint features.
    """

    def __init__(self, d=256, h=96):
        super().__init__()
        # Learned 32-dim embedding per joint, small random init.
        self.je = nn.Parameter(torch.randn(17, 32) * 0.02)
        # Per-joint input = pooled embedding (d) + joint embedding (32) + (x, y) (2).
        self.inp = nn.Linear(d + 32 + 2, h)
        self.g1 = nn.Linear(h, h)
        self.g2 = nn.Linear(h, h)
        self.out = nn.Linear(h, 2)
        # Fixed (non-trainable) adjacency, stored as a buffer so it follows the module.
        self.register_buffer("A", torch.tensor(_A))

    def forward(self, z, kp0):
        """Refine ``kp0`` ([B, 17, 2]) using the pooled embedding ``z`` ([B, d])."""
        batch = z.shape[0]
        per_joint = torch.cat(
            [
                z.unsqueeze(1).expand(-1, 17, -1),
                self.je.unsqueeze(0).expand(batch, -1, -1),
                kp0,
            ],
            -1,
        )
        feat = torch.relu(self.inp(per_joint))
        # Two graph layers: average over skeleton neighbours, then Linear + ReLU.
        for layer in (self.g1, self.g2):
            mixed = torch.einsum('ij,bjh->bih', self.A, feat)
            feat = torch.relu(layer(mixed))
        offset = 0.3 * torch.tanh(self.out(feat))
        return kp0 + offset
class PoseNet(nn.Module):
    """Flagship pose model. Input [B,3,114,10] CSI amplitude (per-sample standardized) -> [B,34].

    Pipeline: each time step's ``na * nsc`` values are linearly projected to ``d``
    dimensions, a learned positional term is added, a transformer encoder runs over
    the ``nt`` time steps, the sequence is pooled with softmax attention weights, an
    MLP head with a sigmoid produces 17 initial (x, y) keypoints, and the
    graph-refinement module ``gr`` adjusts them. The output is flattened to 34 values.

    Args:
        na: Size of input axis 1 (presumably antennas).
        nsc: Size of input axis 2 (presumably subcarriers).
        nt: Size of input axis 3; the number of time steps fed to the transformer.
        d: Model width.
        L: Number of transformer encoder layers.
        H: Number of attention heads.
    """

    def __init__(self, na=3, nsc=114, nt=10, d=256, L=4, H=8):
        super().__init__()
        self.proj = nn.Linear(na * nsc, d)
        self.pos = nn.Parameter(torch.randn(1, nt, d) * 0.02)
        enc = nn.TransformerEncoderLayer(d, H, d * 2, dropout=0.2, batch_first=True, activation='gelu')
        self.tf = nn.TransformerEncoder(enc, L)
        self.att = nn.Linear(d, 1)
        self.head = nn.Sequential(nn.Linear(d, 256), nn.GELU(), nn.Dropout(0.3), nn.Linear(256, 34))
        self.gr = GR(d)
        self.na, self.nsc, self.nt = na, nsc, nt

    def forward(self, x):
        """Map ``x`` of shape [B, na, nsc, nt] to flattened keypoints of shape [B, 34]."""
        batch = x.shape[0]
        # [B, na, nsc, nt] -> [B, nt, na * nsc]: one token per time step.
        tokens = x.permute(0, 3, 1, 2).reshape(batch, self.nt, self.na * self.nsc)
        hidden = self.tf(self.proj(tokens) + self.pos)
        # Temporal attention pooling: softmax over the time axis, then a weighted sum.
        weights = torch.softmax(self.att(hidden), 1)
        pooled = (hidden * weights).sum(1)
        # The sigmoid keeps the initial estimate in (0, 1) before graph refinement.
        coarse = torch.sigmoid(self.head(pooled)).reshape(batch, 17, 2)
        return self.gr(pooled, coarse).reshape(batch, 34)

    def add_lora(self, r=8, alpha=16):
        """Wrap the input projection and both pose-head Linear layers with LoRA adapters.

        The wrapped layers are replaced in place and frozen by ``LoRA``. With the
        default layer sizes and ``r=8`` the adapters hold 11,200 trainable values.

        Returns:
            ``self``, so the call can be chained.
        """
        self.proj = LoRA(self.proj, r, alpha)
        self.head[0] = LoRA(self.head[0], r, alpha)
        self.head[3] = LoRA(self.head[3], r, alpha)
        return self

    def lora_state(self) -> dict:
        """Extract just the LoRA A/B tensors (the per-room adapter to save).

        Returns:
            A dict mapping state-dict style names (e.g. ``proj.A``) to NumPy arrays.
        """
        # Filter parameters rather than the full state dict: the state dict also holds
        # buffers, and the skeleton adjacency is registered as the buffer ``gr.A``,
        # which would otherwise match the ``.A`` suffix and leak into the adapter.
        return {
            name: param.detach().cpu().numpy()
            for name, param in self.named_parameters()
            if name.endswith((".A", ".B"))
        }

    def load_lora(self, adapter: dict):
        """Load adapter tensors produced by ``lora_state`` into this model.

        Every entry of ``adapter`` overrides the same-named entry of the current
        state dict; all other weights are left as they are. Loading is strict, so a
        key this model does not have (for example when ``add_lora`` has not been
        called) makes ``load_state_dict`` raise.

        Args:
            adapter: Mapping from state-dict key to array-like values.

        Returns:
            ``self``, so the call can be chained.
        """
        merged = self.state_dict()
        merged.update({key: torch.tensor(value) for key, value in adapter.items()})
        self.load_state_dict(merged)
        return self
def standardize(x: torch.Tensor) -> torch.Tensor:
    """Standardize each sample of a 4-D batch independently (training and inference).

    The mean and standard deviation are taken over axes 1, 2 and 3, so every sample
    along axis 0 is centred and scaled on its own. A small constant is added to the
    standard deviation to avoid division by zero.
    """
    sample_axes = (1, 2, 3)
    centred = x - x.mean(sample_axes, keepdim=True)
    spread = x.std(sample_axes, keepdim=True)
    return centred / (spread + 1e-6)