# HIGGS Dynamic Bitwidth Quantization

This implements Section 5 from the HIGGS paper: **Variable Bitwidth Quantization** with data-free calibration.

## Overview

The algorithm finds optimal per-layer bitwidths that minimize perplexity degradation while meeting a target average bitrate. It consists of three steps:

### Step 1: Algorithm 3 — Estimate α_l Coefficients (Data-Free)

Instead of measuring perplexity on a calibration dataset, we use **KL divergence on random tokens**:

1. Generate random token IDs (no dataset needed)
2. Run through the clean model → get clean logits
3. For each layer `l` and noise level `t_j`:
   - Add Gaussian noise `N(0, t_j²)` to layer `l`'s weights
   - Run through the noised model → get noised logits
   - Compute `KL(p_clean || p_noised)`
4. Fit α_l via least squares: minimize `Σ_j (ΔKL_j - α_l · t²_j)²`

**Why this works:** We don't need coherent text to measure how much a layer perturbation distorts the output distribution. Random tokens are sufficient because we're measuring relative distortion, not absolute quality.
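
For concreteness, here is a sketch of how one layer's α could be estimated. The structure and helper names (`fit_alpha_for_layer`, passing the layer's parameters explicitly, the Hugging Face-style `model(input_ids).logits` call) are assumptions for illustration, not the actual API of `dynamic_bitwidth.py`:

```python
# Sketch of Algorithm 3 for a single layer (illustrative, not the script's API).
import torch
import torch.nn.functional as F

@torch.no_grad()
def fit_alpha_for_layer(model, layer_params, input_ids, clean_logits, noise_levels):
    """Inject Gaussian noise into one layer's weights at several levels t_j,
    measure KL(p_clean || p_noised) on random tokens, and fit alpha by least squares."""
    delta_kl = []
    for t in noise_levels:
        originals = [p.clone() for p in layer_params]
        for p in layer_params:                              # W <- W + N(0, t^2)
            p.add_(torch.randn_like(p) * t)
        noised_logits = model(input_ids).logits
        kl = F.kl_div(                                      # KL(p_clean || p_noised)
            F.log_softmax(noised_logits, dim=-1).flatten(0, -2),
            F.log_softmax(clean_logits, dim=-1).flatten(0, -2),
            log_target=True, reduction="batchmean",         # mean over token positions
        )
        delta_kl.append(kl.item())
        for p, orig in zip(layer_params, originals):        # restore the clean weights
            p.copy_(orig)

    # Least squares for  min_alpha  sum_j (dKL_j - alpha * t_j^2)^2  has the
    # closed form  alpha = sum_j dKL_j * t_j^2 / sum_j t_j^4.
    t2 = torch.tensor(noise_levels, dtype=torch.float64) ** 2
    dkl = torch.tensor(delta_kl, dtype=torch.float64)
    return float((dkl * t2).sum() / (t2 * t2).sum())
```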

### Step 2: Compute Quantization Error Table

For each layer `l` and each quantization option `j`, compute:
- `t²_{l,j}` = MSE of quantizing layer `l` with option `j`

This is purely about weight quantization error — no model inference needed.
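
As a sketch, the error table is just a round-trip MSE per (layer, option) pair. `quantize_dequantize` below is a stand-in for whatever quantizer the chosen `--norm` and option imply, not a function the script actually exports:

```python
# Sketch of the error-table computation (quantize_dequantize is a placeholder).
import torch

@torch.no_grad()
def build_error_table(layer_weights, options, quantize_dequantize):
    """error_table[l][j] = MSE of quantizing layer l's weight with option j."""
    error_table, total_elements = {}, {}
    for l, weight in enumerate(layer_weights):
        total_elements[l] = weight.numel()
        error_table[l] = {}
        for j, opt in enumerate(options):
            w_hat = quantize_dequantize(weight, opt)        # quantize, then dequantize
            error_table[l][j] = torch.mean((weight - w_hat) ** 2).item()
    return error_table, total_elements
```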

### Step 3: Optimize Bitwidth Assignment

Solve the constrained optimization:

```
min  Σ_l α_l · t²_{l,j_l}
s.t. Σ_l b_{j_l} · d_l ≤ b_max · d
```

where:
- `α_l` = layer sensitivity coefficient (from Step 1)
- `t²_{l,j}` = quantization error for layer l with option j (from Step 2)
- `b_j` = bits per element for option j
- `d_l` = number of elements in layer l
- `d` = total number of elements, `Σ_l d_l`
- `b_max` = target average bitrate

We use a **greedy knapsack approach**: start every layer at the cheapest option, then iteratively upgrade the layer with the best marginal benefit (PPL improvement per bit spent) until the bit budget is exhausted.
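
A minimal version of that greedy loop is sketched below (not the exact solver in `dynamic_bitwidth.py`): every layer starts at the cheapest option, and bits are repeatedly spent on the layer/option upgrade with the largest drop in `α_l · t²_{l,j}` per extra bit, as long as the budget allows:

```python
def greedy_assign(alphas, error_table, bits, sizes, target_bits):
    """alphas[l] = sensitivity, error_table[l][j] = quantization MSE,
    bits[j] = bits per element of option j, sizes[l] = elements in layer l."""
    n_layers = len(alphas)
    order = sorted(range(len(bits)), key=lambda j: bits[j])   # cheapest option first
    assign = [order[0]] * n_layers                            # start everything at the cheapest option
    budget = target_bits * sum(sizes)                         # total bit budget
    used = sum(bits[j] * sizes[l] for l, j in enumerate(assign))

    while True:
        best = None
        for l in range(n_layers):
            cur = assign[l]
            for j in order:
                extra = (bits[j] - bits[cur]) * sizes[l]
                if extra <= 0 or used + extra > budget:
                    continue
                # marginal benefit: reduction in alpha_l * t^2_{l,j} per extra bit
                gain = alphas[l] * (error_table[l][cur] - error_table[l][j]) / extra
                if gain > 0 and (best is None or gain > best[0]):
                    best = (gain, l, j, extra)
        if best is None:                                      # budget exhausted or no useful upgrade
            break
        _, l, j, extra = best
        assign[l], used = j, used + extra
    return assign, used / sum(sizes)
```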

## Usage

### Quick Start — Run All Steps

```bash
cd baselines/opt_sym

# Single target bitrate
python dynamic_bitwidth.py \
    --model meta-llama/Llama-3.1-8B \
    --seqlen 2048 \
    --calibrate \
    --compute-error-table \
    --optimize \
    --target-bits 3.0 \
    --bitwidth-options "k=2,p=2;k=3,p=2;k=4,p=2"
```

This will create:
- `alphas.json` — Per-layer sensitivity coefficients
- `error_table.json` — Quantization error for each layer/option
- `assignment.json` — Optimized bitwidth assignment

### Step-by-Step

**Step 1: Calibrate α_l**

```bash
python dynamic_bitwidth.py \
    --model meta-llama/Llama-3.1-8B \
    --seqlen 2048 \
    --calibrate \
    --calibration-tokens 287000 \
    --output-alpha alphas.json
```

Options:
- `--calibration-tokens`: Number of random tokens (default: 287k, matching HIGGS paper)
- `--n-noise-levels`: Number of noise levels J (default: 15)
- `--t-min`, `--t-max`: Noise range (default: 0.001 to 0.05)

**Step 2: Compute Error Table**

```bash
python dynamic_bitwidth.py \
    --model meta-llama/Llama-3.1-8B \
    --compute-error-table \
    --bitwidth-options "k=2,p=2;k=3,p=2;k=4,p=2" \
    --norm l2 \
    --error-table-path error_table.json
```

Options:
- `--bitwidth-options`: Semicolon-separated options, each as `k=X,p=Y`
- `--norm`: `l2` (HIGGS-style) or `absmax` (BNF-style)
- `--rot-blocksize`: Rotation block size for L2 norm (default: 128)

**Step 3: Optimize Assignment**

```bash
python dynamic_bitwidth.py \
    --model meta-llama/Llama-3.1-8B \
    --optimize \
    --alpha-path alphas.json \
    --error-table-path error_table.json \
    --target-bits 3.0 \
    --output-assignment assignment.json
```

### Batch Job on Babel

```bash
# Single target
sbatch run_dynamic_bitwidth.sh \
    --model meta-llama/Llama-3.1-8B \
    --target-bits 3.0

# Array job for multiple bitrates
# (Slurm array indices must be integers; here indices 0-4 are assumed to map to
#  bitrates 2.0, 2.5, 3.0, 3.5, 4.0 inside run_dynamic_bitwidth.sh)
sbatch --array=0-4 \
    run_dynamic_bitwidth.sh \
    --model meta-llama/Llama-3.1-8B
```

## Output Format

### alphas.json

```json
{
  "alphas": {
    "0": 0.123,
    "1": 0.456,
    ...
  },
  "n_layers": 32,
  "calibration_tokens": 287000,
  "n_noise_levels": 15,
  ...
}
```

### error_table.json

```json
{
  "error_table": {
    "0": {"0": 0.0012, "1": 0.0008, "2": 0.0005},
    ...
  },
  "total_elements": {"0": 4194304, ...},
  "options": [
    {"k": 2, "p": 2, "index_bits": 4, "bits_per_entry": 2, "config_str": "k2p2"},
    ...
  ],
  "norm": "l2",
  ...
}
```

### assignment.json

```json
{
  "assignment": {
    "0": 2,    // Layer 0 uses option index 2 (k=4,p=2)
    "1": 1,    // Layer 1 uses option index 1 (k=3,p=2)
    ...
  },
  "avg_bits": 3.012,
  "target_bits": 3.0,
  "expected_ppl_degradation": 0.0456,
  "option_counts": {
    "k2p2": 8,
    "k3p2": 18,
    "k4p2": 6
  }
}
```
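
To consume these files downstream, something like the following works (field names as in the examples above; the final quantization call is a hypothetical placeholder):

```python
import json

with open("assignment.json") as f:
    result = json.load(f)
with open("error_table.json") as f:
    options = json.load(f)["options"]

for layer_idx, opt_idx in result["assignment"].items():
    opt = options[opt_idx]
    print(f"layer {layer_idx}: k={opt['k']}, p={opt['p']} ({opt['config_str']})")
    # quantize_layer(model, int(layer_idx), opt)  # hypothetical hook into the quantizer
print(f"avg bits {result['avg_bits']:.3f} (target {result['target_bits']})")
```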

## Implementation Notes

### Algorithm 3 (Data-Free Calibration)

The key insight is that we can estimate per-layer sensitivity without a calibration dataset. Instead of:

```
Δ_{l,j} = PPL(W*(l, t_j)) - PPL(W*)     // needs WikiText-2
```

we use:

```
Δ_{l,j} = KL(p_clean || p_noised)       // works with random tokens
```

This makes the calibration **fully self-contained** — no data download or preparation needed.
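
For reference, the KL term is the standard formula applied to the two logit tensors from the same batch of random tokens (a self-contained sketch; the script's own helper may be organized differently):

```python
import torch.nn.functional as F

def kl_clean_vs_noised(clean_logits, noised_logits):
    """KL(p_clean || p_noised), summed over the vocabulary and
    averaged over token positions; inputs are [batch, seq, vocab] logits."""
    log_p = F.log_softmax(clean_logits, dim=-1)   # clean distribution (log)
    log_q = F.log_softmax(noised_logits, dim=-1)  # noised distribution (log)
    return (log_p.exp() * (log_p - log_q)).sum(dim=-1).mean()
```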

### Computational Cost

- **Calibration**: `L × J` forward passes = ~32 layers × 15 noise levels = 480 passes
- **Error table**: One quantization per layer per option
- **Optimization**: Negligible (a greedy pass over the precomputed tables)

For Llama-3.1-8B with 287k tokens at seqlen=2048:
- Calibration: ~4-6 hours on a single GPU
- Error table: ~30 minutes

### Extending to Per-Entry Quantization

The current implementation assigns the same bitwidth to all weights in a layer. To extend to per-entry quantization (like HIGGS with per-entry sensitivity):

1. Compute κ_{l,i} for each weight element (like the existing κ analysis)
2. Modify the error table to track per-entry MSE
3. Add a constraint to the optimization ensuring each entry gets at least one bit
4. Solve with a more sophisticated solver (e.g., water-filling); a sketch is given below

This is future work — the current layer-wise assignment is already a significant improvement over uniform bitwidths.
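
As a rough illustration of that last step, classic reverse water-filling allocates (continuous) bits in proportion to log-sensitivity under an average-bit budget with a per-entry floor. This is a generic sketch of the textbook procedure, not something the current code implements:

```python
import numpy as np

def water_filling_bits(sensitivity, avg_bits, min_bits=1.0, max_rounds=50):
    """Minimize sum_i sensitivity[i] * 2**(-2*b_i)  s.t.  mean(b) == avg_bits, b_i >= min_bits.
    KKT condition: every active entry satisfies b_i = 0.5*log2(sensitivity_i / lambda)."""
    s = np.asarray(sensitivity, dtype=np.float64)
    n = s.size
    bits = np.full(n, min_bits)
    active = np.ones(n, dtype=bool)
    for _ in range(max_rounds):
        if not active.any():
            break
        # Bits available to the active set after reserving the floor for pinned entries
        budget = avg_bits * n - min_bits * (~active).sum()
        log_s = 0.5 * np.log2(s[active])
        b_active = log_s + (budget - log_s.sum()) / active.sum()   # shift so bits sum to budget
        if (b_active >= min_bits).all():
            bits[active] = b_active
            break
        # Entries that fell below the floor get pinned at min_bits and removed from the active set
        pinned = np.where(active)[0][b_active < min_bits]
        active[pinned] = False
    return bits
```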