Options, roughly by effort/impact:
┌─────┬──────────────────────────────────────────────────────────────┬───────────────┬────────────────┐
│ #   │ Optimization                                                 │ Expected TPS  │ Effort         │
├─────┼──────────────────────────────────────────────────────────────┼───────────────┼────────────────┤
│ 1   │ Tensor compression: FP16→FP8 for hidden state transport      │ ~21           │ Small          │
│     │ (halves TCP time)                                            │               │                │
├─────┼──────────────────────────────────────────────────────────────┼───────────────┼────────────────┤
│ 2   │ KV cache prefix reuse: persist KV caches across multi-turn   │ huge for chat │ Medium         │
│     │ conversations                                                │               │                │
├─────┼──────────────────────────────────────────────────────────────┼───────────────┼────────────────┤
│ 3   │ Add 2nd worker: 3-node sharding balances load better         │ ~25-28        │ Medium         │
│     │ (53+13 is lopsided)                                          │               │ (hardware)     │
├─────┼──────────────────────────────────────────────────────────────┼───────────────┼────────────────┤
│ 4   │ Continuous batching: process multiple inference requests     │ 2-3x          │ Large          │
│     │ concurrently                                                 │ throughput    │                │
├─────┼──────────────────────────────────────────────────────────────┼───────────────┼────────────────┤
│ 5   │ Speculative decoding with draft model: small model on coord  │ ~22-25        │ Large          │
│     │ predicts N tokens, worker verifies batch                     │               │                │
├─────┼──────────────────────────────────────────────────────────────┼───────────────┼────────────────┤
│ 6   │ RDMA/JACCL transport: bypass TCP stack entirely (already     │ ~22           │ Small (config) │
│     │ compiled, needs rdma_ctl enable)                             │               │                │
└─────┴──────────────────────────────────────────────────────────────┴───────────────┴────────────────┘
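Row 5's verify step can be sketched for greedy decoding. This is a minimal Python sketch, assuming both the draft and the target model pick tokens by argmax; `verify_greedy` and its arguments are hypothetical names, not this project's code. One batched verification pass over the N draft tokens yields the target model's argmax at N+1 positions, and the coord keeps the longest prefix on which the two models agree.

```python
from typing import List, Sequence


def verify_greedy(draft: Sequence[int], target_argmax: Sequence[int]) -> List[int]:
    """Return the tokens to commit after one speculative step (greedy case).

    draft:         N tokens proposed by the draft model (hypothetical input).
    target_argmax: N+1 tokens; entry i is the target model's argmax after the
                   prefix plus draft[:i], all from one batched verify pass.
    """
    if len(target_argmax) != len(draft) + 1:
        raise ValueError("target_argmax must have len(draft) + 1 entries")
    committed: List[int] = []
    for proposed, expected in zip(draft, target_argmax):
        if proposed != expected:
            # First mismatch: keep the target's own token and stop, because the
            # remaining draft tokens were conditioned on a wrong token.
            committed.append(expected)
            return committed
        committed.append(proposed)
    # Every draft token matched, so the extra position gives a free bonus token.
    committed.append(target_argmax[-1])
    return committed


if __name__ == "__main__":
    print(verify_greedy([5, 7, 9], [5, 7, 2, 4]))  # [5, 7, 2]
    print(verify_greedy([5, 7, 9], [5, 7, 9, 4]))  # [5, 7, 9, 4]
```

Each verify round commits between 1 and N+1 tokens for a single round trip to the worker, so the gain depends on how often the draft model agrees with the full model.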
Other Optimization Options
1. Compute/Communication Overlap (Double Buffering)
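Instead of waiting for the worker to return, the coord starts computing the next token's layers while the worker processes the current one. This requires pipelining at the token level: send token N's hidden state to the worker, then immediately start computing token N+1's hidden state on the coord. When the worker's result comes back, use it for token N+2. This could theoretically overlap ~30ms of coord compute with ~33ms of worker compute, so throughput approaches the limit set by the slower stage alone (the worker, ~33ms/token ≈ 30 tok/s). Caveat: in single-stream autoregressive decoding, token N+1's input is the token sampled from token N's output, so the coord only has useful work to overlap when that input is available early, e.g. from a second independent request or from a guessed token that is verified afterwards.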
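The overlap argument reduces to a simple throughput bound. This is a back-of-envelope Python sketch, not a measurement: `sequential_tps` and `pipelined_tps` are hypothetical helpers, the inputs are the ~30ms/~33ms compute estimates from the paragraph above, and TCP time is left out, as it is there. The pipelined figure is an upper bound that only holds if the pipeline always has independent work to fill it.

```python
from typing import Sequence


def sequential_tps(stage_ms: Sequence[float]) -> float:
    """Tokens/s when each token waits for every stage in turn."""
    return 1000.0 / sum(stage_ms)


def pipelined_tps(stage_ms: Sequence[float]) -> float:
    """Upper bound on tokens/s when stages overlap perfectly (double buffering).

    In steady state a token can leave the pipeline only as fast as the slowest
    stage finishes its share, so that stage sets the rate.
    """
    return 1000.0 / max(stage_ms)


if __name__ == "__main__":
    coord_ms, worker_ms = 30.0, 33.0  # compute estimates quoted above
    print(f"sequential (compute only): {sequential_tps([coord_ms, worker_ms]):.1f} tok/s")
    print(f"overlapped (upper bound):  {pipelined_tps([coord_ms, worker_ms]):.1f} tok/s")
```

Adding the transport as a third stage (e.g. `[30.0, 6.0, 33.0]`) leaves the pipelined bound unchanged as long as it stays shorter than the worker stage, which is why overlap hides TCP time as well as coord compute.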
2. Asynchronous Batched Decode
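Buffer multiple tokens and send them to the worker as a batch. The worker processes all K tokens in one forward pass, so the ~6ms TCP overhead is paid once per batch instead of once per token (total cost drops from K×(6ms + compute) to 6ms + K×compute). This is basically what speculative decoding does, but without the guesswork: just queue up K tokens of hidden states. That only works where the K hidden states do not depend on each other's outputs, such as prompt prefill or K concurrent requests; in single-stream decode each token's hidden state needs the previous token, which is why speculative decoding has to guess.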
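A minimal sketch of the batching side, assuming the K hidden states are already available and serialized to equal-sized byte rows (for example 2048 float32 values = 8192 bytes each). The wire format below (a little-endian row count and row size, then the rows back to back) and the names `pack_batch`, `unpack_batch` and `round_trip_cost_ms` are hypothetical, not the project's actual protocol.

```python
import struct
from typing import List

# Hypothetical framing: uint32 row count K, uint32 bytes per row, then K rows.
_HEADER = struct.Struct("<II")


def pack_batch(rows: List[bytes]) -> bytes:
    """Frame K hidden-state rows as one message so a single send covers all K."""
    if not rows:
        raise ValueError("empty batch")
    row_len = len(rows[0])
    if any(len(r) != row_len for r in rows):
        raise ValueError("all rows must have the same size")
    return _HEADER.pack(len(rows), row_len) + b"".join(rows)


def unpack_batch(msg: bytes) -> List[bytes]:
    """Inverse of pack_batch: split one message back into its K rows."""
    count, row_len = _HEADER.unpack_from(msg, 0)
    body = memoryview(msg)[_HEADER.size:]
    if len(body) != count * row_len:
        raise ValueError("truncated or oversized message")
    return [bytes(body[i * row_len:(i + 1) * row_len]) for i in range(count)]


def round_trip_cost_ms(k: int, overhead_ms: float, compute_ms: float, batched: bool) -> float:
    """Per-message overhead paid K times (unbatched) or once (batched)."""
    if batched:
        return overhead_ms + k * compute_ms
    return k * (overhead_ms + compute_ms)
```

Using `compute_ms` per token for the batched case is conservative: a batched forward pass is often cheaper than K separate ones, so the real saving can be larger than this model shows.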
3. Reduce Coord Layers (Rebalance Sharding)
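Currently the coord has 53 layers and the worker has 13. If we moved more layers to the worker (e.g., a 40/26 split), the coord would be faster (~20ms) and the worker slower (~45ms). This doesn't help the sequential case, where per-token latency is the sum of both stages, but if we can overlap them (option 1), a balanced split maximizes throughput. Balance here means per-stage time, not layer count: by the estimates in option 1, the current 53/13 split (~30ms/~33ms) is already close to time-balanced, while 40/26 (~20ms/~45ms) would make the worker a ~45ms bottleneck (~22 tok/s).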
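Picking the split that minimizes the slower stage can be sketched with a linear cost model. This is a rough Python sketch under a stated assumption: per-layer cost is constant on each machine, fitted to the estimates in these notes (coord: 53 layers ≈ 30ms; worker: 13 layers ≈ 33ms and 26 layers ≈ 45ms, i.e. ~12ms per 13 extra layers plus ~21ms of fixed norm/head cost). The fit predicts ~22.6ms for a 40-layer coord, roughly consistent with the quoted ~20ms. `stage_times_ms` and `best_split` are hypothetical helpers.

```python
from typing import Tuple

TOTAL_LAYERS = 66  # 53 on coord + 13 on worker today


def stage_times_ms(worker_layers: int,
                   coord_ms_per_layer: float,
                   worker_ms_per_layer: float,
                   worker_fixed_ms: float) -> Tuple[float, float]:
    """(coord_ms, worker_ms) for a given split under the linear cost model."""
    coord_layers = TOTAL_LAYERS - worker_layers
    return (coord_layers * coord_ms_per_layer,
            worker_fixed_ms + worker_layers * worker_ms_per_layer)


def best_split(coord_ms_per_layer: float,
               worker_ms_per_layer: float,
               worker_fixed_ms: float) -> int:
    """Worker layer count that minimizes the slower stage (the overlapped bottleneck)."""
    return min(
        range(TOTAL_LAYERS + 1),
        key=lambda w: max(stage_times_ms(w, coord_ms_per_layer,
                                         worker_ms_per_layer, worker_fixed_ms)),
    )


if __name__ == "__main__":
    # Assumed fit: coord 30ms / 53 layers; worker 12ms per 13 layers + 21ms fixed.
    params = (30.0 / 53, 12.0 / 13, 21.0)
    w = best_split(*params)
    coord_ms, worker_ms = stage_times_ms(w, *params)
    print(f"{TOTAL_LAYERS - w}/{w} split: coord {coord_ms:.1f}ms, worker {worker_ms:.1f}ms")
```

Under this fit the minimum lands at 55/11 (about 31.1ms vs 31.2ms), close to today's 53/13, so shifting layers toward the worker would not pay off once the stages overlap. The fit is only as good as the three estimates it is built from.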
4. Tensor Compression / Quantized Transport
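The hidden state sent over TCP is [1, 1, 2048] float32 = 2048 × 4 bytes ≈ 8KB per token. For the speculator's combined hidden state it's [1, K+1, 2048], ~40KB (which corresponds to K = 4). FP16 would halve this and FP8 would quarter it, reducing TCP latency to the extent that the transfer is bandwidth-bound rather than dominated by fixed per-message overhead.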
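A minimal standard-library sketch of the payload sizes and encodings, assuming the hidden state arrives as a flat list of 2048 Python floats; all function names are hypothetical. FP16 uses `struct`'s built-in half-precision `'e'` format. `struct` has no FP8 type, so the 1-byte case is shown with symmetric int8 absmax quantization instead: a different encoding with a similar size, not FP8 itself.

```python
import struct
from typing import List


def to_fp16_bytes(values: List[float]) -> bytes:
    """IEEE 754 half precision via struct's 'e' format: 2 bytes per value."""
    return struct.pack(f"<{len(values)}e", *values)


def from_fp16_bytes(data: bytes) -> List[float]:
    return list(struct.unpack(f"<{len(data) // 2}e", data))


def to_int8_bytes(values: List[float]) -> bytes:
    """Symmetric per-tensor int8: a float32 scale, then 1 signed byte per value.

    Stand-in for FP8 (which struct cannot encode), not FP8 itself.
    """
    absmax = max((abs(v) for v in values), default=0.0)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return struct.pack("<f", scale) + struct.pack(f"<{len(q)}b", *q)


def from_int8_bytes(data: bytes) -> List[float]:
    (scale,) = struct.unpack_from("<f", data, 0)
    q = struct.unpack_from(f"<{len(data) - 4}b", data, 4)
    return [v * scale for v in q]
```

For a 2048-wide row this gives 8192 bytes as float32, 4096 as FP16, and 2052 (2048 plus a 4-byte scale) as int8. A real implementation would more likely cast the framework's tensor directly; the point here is only the size and round-trip error trade-off.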
5. KV Cache Quantization
Compress KV cache on coord to free more GPU memory, allowing larger batch sizes or longer context windows. Doesn't directly improve single-stream TPS but improves memory efficiency.
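How much memory KV quantization frees can be estimated with the standard KV-cache size formula: 2 (K and V) × layers × batch × sequence length × KV heads × head dimension × bytes per element. In this minimal sketch only the 53 coord layers come from these notes; the KV head count, head dimension and context length are hypothetical placeholders, not this model's config, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: float, batch: int = 1) -> int:
    """K and V per layer: 2 * batch * seq_len * kv_heads * head_dim elements."""
    return int(2 * layers * batch * seq_len * kv_heads * head_dim * bytes_per_elem)


if __name__ == "__main__":
    # 53 coord layers from the notes; kv_heads, head_dim and seq_len are
    # hypothetical placeholders, so substitute the real model config.
    shape = dict(layers=53, kv_heads=8, head_dim=128, seq_len=4096)
    for name, nbytes in [("fp16", 2), ("8-bit", 1), ("4-bit", 0.5)]:
        gib = kv_cache_bytes(**shape, bytes_per_elem=nbytes) / 2**30
        print(f"{name:>5}: {gib:.2f} GiB")
```

Halving bytes per element halves the cache, and the freed memory can go to more concurrent sequences (`batch`) or a longer `seq_len` at the same footprint.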
6. Kernel Fusion on Worker
The worker only runs 13 layers + norm + head. Fusing norm + head + argmax into a single kernel would reduce its latency from 33ms to maybe 20ms, which directly translates to higher TPS.
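What a fused norm + head + argmax kernel must compute can be pinned down with a plain-Python reference, assuming the final norm is RMSNorm (the notes don't say which norm) and greedy argmax sampling; all names are hypothetical. This is not a GPU kernel, only the semantics one has to match. The fused version never materializes the full logits vector: it keeps a running (index, value) pair while walking the head rows, which on a GPU is what removes the logits write/read and the separate argmax launch.

```python
import math
from typing import List, Sequence


def rms_norm(h: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    inv_rms = 1.0 / math.sqrt(sum(x * x for x in h) / len(h) + eps)
    return [x * inv_rms * w for x, w in zip(h, weight)]


def unfused_next_token(h, norm_weight, head_rows, eps=1e-6) -> int:
    """Three separate steps; the full logits vector is materialized."""
    normed = rms_norm(h, norm_weight, eps)
    logits = [sum(a * b for a, b in zip(normed, row)) for row in head_rows]
    return max(range(len(logits)), key=logits.__getitem__)


def fused_next_token(h, norm_weight, head_rows, eps=1e-6) -> int:
    """Same result in one pass over the head, keeping only the running best."""
    normed = rms_norm(h, norm_weight, eps)
    best_id, best_logit = -1, -math.inf
    for token_id, row in enumerate(head_rows):
        logit = sum(a * b for a, b in zip(normed, row))
        if logit > best_logit:  # strict '>' keeps the first maximum, like max()
            best_id, best_logit = token_id, logit
    return best_id
```

The unfused version doubles as a test oracle: a fused kernel's output should equal it on random inputs, including ties.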