# HIGGS Dynamic Bitwidth Quantization

This implements Section 5 of the HIGGS paper: **Variable Bitwidth Quantization** with data-free calibration.

## Overview

The algorithm finds optimal per-layer bitwidths that minimize perplexity degradation while meeting a target average bitrate. It consists of three steps:

### Step 1: Algorithm 3 — Estimate α_l Coefficients (Data-Free)

Instead of measuring perplexity on a calibration dataset, we use **KL divergence on random tokens**:

1. Generate random token IDs (no dataset needed)
2. Run them through the clean model → get clean logits
3. For each layer `l` and noise level `t_j`:
   - Add Gaussian noise `N(0, t_j²)` to layer `l`'s weights
   - Run through the noised model → get noised logits
   - Compute `KL(p_clean || p_noised)`
4. Fit α_l via least squares: minimize `Σ_j (ΔKL_j - α_l · t_j²)²`

**Why this works:** We don't need coherent text to measure how much a layer perturbation distorts the output distribution. Random tokens are sufficient because we're measuring relative distortion, not absolute quality.
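
For concreteness, here is a minimal NumPy sketch of the step 4 fit. The closed form follows from setting the derivative of `Σ_j (ΔKL_j - α_l · t_j²)²` to zero; the function name and the geometric spacing of the noise grid are illustrative assumptions, not the script's actual interface.

```python
import numpy as np

def fit_alpha(t: np.ndarray, delta_kl: np.ndarray) -> float:
    """Fit alpha in delta_KL ≈ alpha * t**2 by least squares.

    Setting the derivative of sum_j (delta_kl[j] - alpha * t[j]**2)**2
    to zero gives alpha = sum(t**2 * delta_kl) / sum(t**4).
    """
    t2 = t ** 2
    return float(t2 @ delta_kl / (t2 @ t2))

# Illustrative noise grid: 15 levels between 0.001 and 0.05 (matching
# the defaults below); the actual spacing used is an assumption.
t_levels = np.geomspace(1e-3, 5e-2, 15)
```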

### Step 2: Compute Quantization Error Table

For each layer `l` and each quantization option `j`, compute:

- `t²_{l,j}` = MSE of quantizing layer `l` with option `j`

This is purely about weight quantization error — no model inference needed.
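
As a sketch of what one table entry involves, the snippet below computes the MSE for a single layer/option pair. The `quantize` callable is a placeholder for one `(k, p)` grid quantizer; the real interface in `dynamic_bitwidth.py` is not shown in this document.

```python
import torch

def quant_mse(weight: torch.Tensor, quantize) -> float:
    """MSE t²_{l,j} of quantizing one layer with one option.

    `quantize` stands in for the grid quantizer of a single bitwidth
    option. Only the weights are touched; no forward pass is needed.
    """
    w_hat = quantize(weight)
    return (w_hat - weight).pow(2).mean().item()
```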

### Step 3: Optimize Bitwidth Assignment

Solve the constrained optimization:

```
min   Σ_l α_l · t²_{l,j_l}
s.t.  Σ_l b_{j_l} · d_l ≤ b_max · d
```

where:

- `α_l` = layer sensitivity coefficient (from Step 1)
- `t²_{l,j}` = quantization error for layer `l` with option `j` (from Step 2)
- `b_j` = bits per element for option `j`
- `d_l` = number of elements in layer `l`
- `d` = total number of elements, `Σ_l d_l`
- `b_max` = target average bitrate

We use a **greedy knapsack approach**: start every layer at the cheapest option, then iteratively upgrade the layer with the best marginal benefit (predicted PPL improvement per bit spent) until the bit budget is exhausted.
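
Below is a minimal Python sketch of that greedy loop, assuming the options are sorted by ascending bits per element (so option 0 is the cheapest); all names are illustrative rather than the script's API.

```python
import heapq

def assign_bitwidths(alpha, err, bits, d, b_max):
    """Greedy knapsack sketch.

    alpha[l]  : sensitivity of layer l (Step 1)
    err[l][j] : quantization error t² of layer l, option j (Step 2)
    bits[j]   : bits per element of option j, sorted ascending
    d[l]      : number of elements in layer l
    b_max     : target average bitrate
    """
    L = len(alpha)
    budget = b_max * sum(d)            # total bit budget
    j = [0] * L                        # every layer starts at the cheapest option
    used = bits[0] * sum(d)
    heap = []                          # entries: (-benefit per bit, layer)
    for l in range(L):                 # queue each layer's first upgrade
        gain = alpha[l] * (err[l][0] - err[l][1])
        cost = (bits[1] - bits[0]) * d[l]
        heapq.heappush(heap, (-gain / cost, l))
    while heap:
        _, l = heapq.heappop(heap)
        nxt = j[l] + 1
        cost = (bits[nxt] - bits[j[l]]) * d[l]
        if used + cost > budget:
            continue                   # too expensive; other layers may still fit
        used += cost
        j[l] = nxt
        if nxt + 1 < len(bits):        # queue this layer's next upgrade
            gain = alpha[l] * (err[l][nxt] - err[l][nxt + 1])
            cost = (bits[nxt + 1] - bits[nxt]) * d[l]
            heapq.heappush(heap, (-gain / cost, l))
    return j
```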

## Usage

### Quick Start — Run All Steps

```bash
cd baselines/opt_sym

# Single target bitrate
python dynamic_bitwidth.py \
    --model meta-llama/Llama-3.1-8B \
    --seqlen 2048 \
    --calibrate \
    --compute-error-table \
    --optimize \
    --target-bits 3.0 \
    --bitwidth-options "k=2,p=2;k=3,p=2;k=4,p=2"
```

This will create:

- `alphas.json` — Per-layer sensitivity coefficients
- `error_table.json` — Quantization error for each layer/option
- `assignment.json` — Optimized bitwidth assignment

### Step-by-Step

**Step 1: Calibrate α_l**

```bash
python dynamic_bitwidth.py \
    --model meta-llama/Llama-3.1-8B \
    --seqlen 2048 \
    --calibrate \
    --calibration-tokens 287000 \
    --output-alpha alphas.json
```
84+
Options:
85+
- `--calibration-tokens`: Number of random tokens (default: 287k, matching HIGGS paper)
86+
- `--n-noise-levels`: Number of noise levels J (default: 15)
87+
- `--t-min`, `--t-max`: Noise range (default: 0.001 to 0.05)
88+
89+

**Step 2: Compute Error Table**

```bash
python dynamic_bitwidth.py \
    --model meta-llama/Llama-3.1-8B \
    --compute-error-table \
    --bitwidth-options "k=2,p=2;k=3,p=2;k=4,p=2" \
    --norm l2 \
    --error-table-path error_table.json
```

Options:

- `--bitwidth-options`: Semicolon-separated options, each as `k=X,p=Y`
- `--norm`: `l2` (HIGGS-style) or `absmax` (BNF-style)
- `--rot-blocksize`: Rotation block size for L2 norm (default: 128)
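
The `k=X,p=Y` grammar is simple enough to show directly; the hypothetical parser below only illustrates the flag's format and is not code from `dynamic_bitwidth.py`.

```python
def parse_options(spec: str) -> list[dict]:
    """Parse "k=2,p=2;k=3,p=2" into [{"k": 2, "p": 2}, {"k": 3, "p": 2}]."""
    options = []
    for part in spec.split(";"):
        fields = dict(kv.split("=") for kv in part.split(","))
        options.append({"k": int(fields["k"]), "p": int(fields["p"])})
    return options
```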

**Step 3: Optimize Assignment**

```bash
python dynamic_bitwidth.py \
    --model meta-llama/Llama-3.1-8B \
    --optimize \
    --alpha-path alphas.json \
    --error-table-path error_table.json \
    --target-bits 3.0 \
    --output-assignment assignment.json
```

### Batch Job on Babel

```bash
# Single target
sbatch run_dynamic_bitwidth.sh \
    --model meta-llama/Llama-3.1-8B \
    --target-bits 3.0

# Array job for multiple bitrates (SLURM array indices must be
# integers, so the script maps each index to a target bitrate)
sbatch --array=0-4 \
    run_dynamic_bitwidth.sh \
    --model meta-llama/Llama-3.1-8B
```

## Output Format

### alphas.json

```json
{
  "alphas": {
    "0": 0.123,
    "1": 0.456,
    ...
  },
  "n_layers": 32,
  "calibration_tokens": 287000,
  "n_noise_levels": 15,
  ...
}
```

### error_table.json

```json
{
  "error_table": {
    "0": {"0": 0.0012, "1": 0.0008, "2": 0.0005},
    ...
  },
  "total_elements": {"0": 4194304, ...},
  "options": [
    {"k": 2, "p": 2, "index_bits": 4, "bits_per_entry": 2, "config_str": "k2p2"},
    ...
  ],
  "norm": "l2",
  ...
}
```

### assignment.json

```json
{
  "assignment": {
    "0": 2,   // Layer 0 uses option index 2 (k=4,p=2)
    "1": 1,   // Layer 1 uses option index 1 (k=3,p=2)
    ...
  },
  "avg_bits": 3.012,
  "target_bits": 3.0,
  "expected_ppl_degradation": 0.0456,
  "option_counts": {
    "k2p2": 8,
    "k3p2": 18,
    "k4p2": 6
  }
}
```
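
The `assignment` values index into the `options` list stored in `error_table.json`. A hypothetical snippet joining the two files (file names as produced by the commands above):

```python
import json

# Print each layer's chosen (k, p) configuration by joining
# assignment.json with the options list from error_table.json.
with open("assignment.json") as f:
    assignment = json.load(f)["assignment"]
with open("error_table.json") as f:
    options = json.load(f)["options"]

for layer in sorted(assignment, key=int):
    print(f"layer {layer}: {options[assignment[layer]]['config_str']}")
```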

## Implementation Notes

### Algorithm 3 (Data-Free Calibration)

The key insight is that we can estimate per-layer sensitivity without a calibration dataset. Instead of:

```
Δ_{l,j} = PPL(W*(l, t_j)) - PPL(W*)   // needs WikiText-2
```

we use:

```
Δ_{l,j} = KL(p_clean || p_noised)     // works with random tokens
```

This makes the calibration **fully self-contained** — no data download or preparation needed.
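
A generic sketch of the KL measurement from two logit tensors follows; the shapes and the mean reduction over positions are assumptions, not necessarily what the script does.

```python
import torch
import torch.nn.functional as F

def kl_from_logits(clean_logits: torch.Tensor,
                   noised_logits: torch.Tensor) -> torch.Tensor:
    """Mean KL(p_clean || p_noised) over all token positions.

    Both tensors have shape (batch, seq, vocab).
    """
    log_p = F.log_softmax(clean_logits, dim=-1)
    log_q = F.log_softmax(noised_logits, dim=-1)
    # KL(p || q) = sum_v p * (log p - log q)
    return (log_p.exp() * (log_p - log_q)).sum(-1).mean()
```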

### Computational Cost

- **Calibration**: `L × J` forward passes = ~32 layers × 15 noise levels = 480 passes
- **Error table**: One quantization per layer per option
- **Optimization**: Negligible (greedy knapsack, no model inference)

For Llama-3.1-8B with 287k tokens at seqlen=2048:

- Calibration: ~4-6 hours on a single GPU
- Error table: ~30 minutes
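
As a rough sanity check on these numbers, assuming each of the `L × J` noised passes (plus one clean pass) runs over the full token budget:

```python
# Back-of-envelope calibration cost; assumes every (layer, noise-level)
# pair re-runs all 287k calibration tokens.
layers, noise_levels = 32, 15
tokens, seqlen = 287_000, 2048
seqs_per_pass = tokens // seqlen                    # ≈ 140 sequences
total = (layers * noise_levels + 1) * seqs_per_pass
print(total)                                        # ≈ 67,000 sequence forwards
```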

### Extending to Per-Entry Quantization

The current implementation assigns the same bitwidth to all weights in a layer. To extend to per-entry quantization (like HIGGS with per-entry sensitivity):

1. Compute `κ_{l,i}` for each weight element (as in the existing κ analysis)
2. Modify the error table to track per-entry MSE
3. Add a constraint to the optimization ensuring each entry gets at least one bit
4. Solve with a more sophisticated solver (e.g., water-filling)

This is future work — the current layer-wise assignment is already a significant improvement over uniform bitwidths.
