|
2 | 2 | import platform |
3 | 3 | import time |
4 | 4 |
|
5 | | -import einops |
6 | 5 | from packaging import version |
7 | 6 | import pytest |
8 | 7 | import torch |
@@ -337,59 +336,6 @@ def test_stable_embedding(): |
337 | 336 | layer.reset_parameters() |
338 | 337 |
|
339 | 338 |
|
def quant(x):
    """Symmetric 8-bit quantization with one global absmax scale.

    Maps ``x`` linearly onto the int8 range so that ``abs(x).max()``
    lands on 127. Returns ``(scale, q)`` where ``scale`` is the absmax
    and ``q`` is the rounded int8 tensor.
    """
    scale = x.abs().max()
    quantized = torch.round(x / scale * 127).to(torch.int8)
    return scale, quantized
344 | | - |
345 | | - |
def dequant(c, maxC):
    """Invert absmax quantization: scale int8 codes ``c`` back to floats.

    ``maxC`` is the absmax used at quantization time (code 127 maps back
    to ``maxC``).
    """
    scale = maxC / 127
    return c.float() * scale
348 | | - |
349 | | - |
def mm_dequant(maxA, maxB, C):
    """Dequantize the accumulator ``C`` of an int8 matmul.

    ``maxA``/``maxB`` are the absmax scales of the two quantized
    operands; both per-127 factors are applied to recover float values.
    """
    scale_a = maxA / 127
    scale_b = maxB / 127
    return C.float() * scale_a * scale_b
352 | | - |
353 | | - |
def quant_multi(x, dim):
    """Absmax quantization with one scale per slice along ``dim``.

    Returns ``(scale, q)``: ``scale`` is the per-slice absmax (keepdim,
    so it broadcasts against ``x``), ``q`` the rounded int8 tensor.
    All-zero slices get scale 1.0 to avoid division by zero.
    """
    scale = torch.amax(x.abs(), dim=dim, keepdim=True)
    scale[scale == 0] = 1.0  # guard: all-zero slices would divide by zero
    quantized = torch.round(x / scale * 127).to(torch.int8)
    return scale, quantized
359 | | - |
360 | | - |
def quant_multi_chunk(x, dim, chunk_size=32):
    """Chunked absmax quantization of a 2-D tensor ``x``.

    The tensor is grouped into chunks of ``chunk_size`` along the
    non-reduced axis, an absmax scale is computed per chunk slice, and
    the scale map is expanded back to ``x.shape``. Returns
    ``(scale, q)`` with ``scale`` the same shape as ``x`` and ``q`` the
    rounded int8 tensor. All-zero slices get scale 1.0.

    Raises ``ValueError`` for ``dim`` other than 0 or 1 (the original
    silently hit UnboundLocalError). The chunked axis length must be
    divisible by ``chunk_size``.

    Fix vs. original: the two `einops.rearrange` calls were trivial
    reshapes — "(c a) b -> c a b" and "a (b c) -> a b c" — replaced by
    the equivalent `Tensor.view`, removing the third-party dependency.
    """
    if dim == 1:
        # "(c a) b -> c a b" with c=chunk_size: split rows into chunks.
        x_chunked = x.view(chunk_size, -1, x.shape[1])
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        # "a (b c) -> a b c" with c=chunk_size: split columns into chunks.
        x_chunked = x.view(x.shape[0], -1, chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    else:
        raise ValueError(f"dim must be 0 or 1, got {dim}")
    max1[max1 == 0] = 1.0  # guard: all-zero slices would divide by zero
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)
375 | | - |
376 | | - |
def mean(xx):
    """Return the arithmetic mean of the sized iterable *xx* as a float."""
    # Python 3 true division already yields a float; no explicit cast needed.
    return sum(xx) / len(xx)
379 | | - |
380 | | - |
# Registry of quantization strategies, keyed by name. Each entry is a
# 5-tuple: (quantize_A, quantize_B, dequantize_A, dequantize_B,
# matmul_dequantize). All quantizers share the (x, dim) signature.
def _quant_ignore_dim(x, dim):
    # "linear" quantization uses a single global scale, so the `dim`
    # argument is accepted only for signature compatibility.
    return quant(x)


methods = {
    "linear": (_quant_ignore_dim, _quant_ignore_dim, dequant, dequant, mm_dequant),
    "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant),
}
391 | | - |
392 | | - |
393 | 339 | class TestLLMInt8Functional: |
394 | 340 | @staticmethod |
395 | 341 | def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half): |
|
0 commit comments