Skip to content

Commit 4875da1

Browse files
committed
rm useless function
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent a8cd3c4 commit 4875da1

1 file changed

Lines changed: 0 additions & 54 deletions

File tree

tests/test_functional.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import platform
33
import time
44

5-
import einops
65
from packaging import version
76
import pytest
87
import torch
@@ -337,59 +336,6 @@ def test_stable_embedding():
337336
layer.reset_parameters()
338337

339338

340-
def quant(x):
341-
max1 = torch.abs(x).max()
342-
x = torch.round(x / max1 * 127)
343-
return max1, x.to(torch.int8)
344-
345-
346-
def dequant(c, maxC):
347-
return c.float() * (maxC / 127)
348-
349-
350-
def mm_dequant(maxA, maxB, C):
351-
return C.float() * (maxA / 127) * (maxB / 127)
352-
353-
354-
def quant_multi(x, dim):
355-
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
356-
max1[max1 == 0] = 1.0
357-
x = torch.round(x / max1 * 127)
358-
return max1, x.to(torch.int8)
359-
360-
361-
def quant_multi_chunk(x, dim, chunk_size=32):
362-
if dim == 1:
363-
x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
364-
max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
365-
max1 = torch.tile(max1, (1, 1, x.shape[1]))
366-
max1 = max1.view(x.shape)
367-
elif dim == 0:
368-
x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
369-
max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
370-
max1 = torch.tile(max1, (x.shape[0], 1, 1))
371-
max1 = max1.view(x.shape)
372-
max1[max1 == 0] = 1.0
373-
x = torch.round(x / max1 * 127)
374-
return max1, x.to(torch.int8)
375-
376-
377-
def mean(xx):
378-
return sum(xx) / float(len(xx))
379-
380-
381-
methods = {
382-
"linear": (
383-
lambda x, dim: quant(x),
384-
lambda x, dim: quant(x),
385-
dequant,
386-
dequant,
387-
mm_dequant,
388-
),
389-
"vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant),
390-
}
391-
392-
393339
class TestLLMInt8Functional:
394340
@staticmethod
395341
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half):

0 commit comments

Comments
 (0)