Skip to content

Commit 9a8712d

Browse files
committed
Implement correct fp32 access pattern
1 parent 9c042ea commit 9a8712d

1 file changed

Lines changed: 79 additions & 42 deletions

File tree

csrc/kernels.cu

Lines changed: 79 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1553,48 +1553,85 @@ __global__ void kgemm_4bit_inference_naive(
15531553
float b31 = quant_map[qm_offset + (local_B_4bit[15] & 0xF)] * local_absmax;
15541554

15551555
if (__builtin_expect(inner_idx + 32 <= K, 1)) {
1556-
int4 a_vec0 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8];
1557-
int4 a_vec1 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 1];
1558-
int4 a_vec2 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 2];
1559-
int4 a_vec3 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 3];
1560-
1561-
const T* a0 = reinterpret_cast<const T*>(&a_vec0);
1562-
const T* a1 = reinterpret_cast<const T*>(&a_vec1);
1563-
const T* a2 = reinterpret_cast<const T*>(&a_vec2);
1564-
const T* a3 = reinterpret_cast<const T*>(&a_vec3);
1565-
1566-
local_C0 += (float)a0[0] * b0;
1567-
local_C1 += (float)a0[1] * b1;
1568-
local_C2 += (float)a0[2] * b2;
1569-
local_C3 += (float)a0[3] * b3;
1570-
local_C0 += (float)a0[4] * b4;
1571-
local_C1 += (float)a0[5] * b5;
1572-
local_C2 += (float)a0[6] * b6;
1573-
local_C3 += (float)a0[7] * b7;
1574-
local_C0 += (float)a1[0] * b8;
1575-
local_C1 += (float)a1[1] * b9;
1576-
local_C2 += (float)a1[2] * b10;
1577-
local_C3 += (float)a1[3] * b11;
1578-
local_C0 += (float)a1[4] * b12;
1579-
local_C1 += (float)a1[5] * b13;
1580-
local_C2 += (float)a1[6] * b14;
1581-
local_C3 += (float)a1[7] * b15;
1582-
local_C0 += (float)a2[0] * b16;
1583-
local_C1 += (float)a2[1] * b17;
1584-
local_C2 += (float)a2[2] * b18;
1585-
local_C3 += (float)a2[3] * b19;
1586-
local_C0 += (float)a2[4] * b20;
1587-
local_C1 += (float)a2[5] * b21;
1588-
local_C2 += (float)a2[6] * b22;
1589-
local_C3 += (float)a2[7] * b23;
1590-
local_C0 += (float)a3[0] * b24;
1591-
local_C1 += (float)a3[1] * b25;
1592-
local_C2 += (float)a3[2] * b26;
1593-
local_C3 += (float)a3[3] * b27;
1594-
local_C0 += (float)a3[4] * b28;
1595-
local_C1 += (float)a3[5] * b29;
1596-
local_C2 += (float)a3[6] * b30;
1597-
local_C3 += (float)a3[7] * b31;
1556+
if constexpr (BITS == 16) {
1557+
int4 a_vec0 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8];
1558+
int4 a_vec1 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 1];
1559+
int4 a_vec2 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 2];
1560+
int4 a_vec3 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 3];
1561+
1562+
const T* a0 = reinterpret_cast<const T*>(&a_vec0);
1563+
const T* a1 = reinterpret_cast<const T*>(&a_vec1);
1564+
const T* a2 = reinterpret_cast<const T*>(&a_vec2);
1565+
const T* a3 = reinterpret_cast<const T*>(&a_vec3);
1566+
1567+
local_C0 += (float)a0[0] * b0;
1568+
local_C1 += (float)a0[1] * b1;
1569+
local_C2 += (float)a0[2] * b2;
1570+
local_C3 += (float)a0[3] * b3;
1571+
local_C0 += (float)a0[4] * b4;
1572+
local_C1 += (float)a0[5] * b5;
1573+
local_C2 += (float)a0[6] * b6;
1574+
local_C3 += (float)a0[7] * b7;
1575+
local_C0 += (float)a1[0] * b8;
1576+
local_C1 += (float)a1[1] * b9;
1577+
local_C2 += (float)a1[2] * b10;
1578+
local_C3 += (float)a1[3] * b11;
1579+
local_C0 += (float)a1[4] * b12;
1580+
local_C1 += (float)a1[5] * b13;
1581+
local_C2 += (float)a1[6] * b14;
1582+
local_C3 += (float)a1[7] * b15;
1583+
local_C0 += (float)a2[0] * b16;
1584+
local_C1 += (float)a2[1] * b17;
1585+
local_C2 += (float)a2[2] * b18;
1586+
local_C3 += (float)a2[3] * b19;
1587+
local_C0 += (float)a2[4] * b20;
1588+
local_C1 += (float)a2[5] * b21;
1589+
local_C2 += (float)a2[6] * b22;
1590+
local_C3 += (float)a2[7] * b23;
1591+
local_C0 += (float)a3[0] * b24;
1592+
local_C1 += (float)a3[1] * b25;
1593+
local_C2 += (float)a3[2] * b26;
1594+
local_C3 += (float)a3[3] * b27;
1595+
local_C0 += (float)a3[4] * b28;
1596+
local_C1 += (float)a3[5] * b29;
1597+
local_C2 += (float)a3[6] * b30;
1598+
local_C3 += (float)a3[7] * b31;
1599+
} else {
1600+
const float* a = reinterpret_cast<const float*>(A_row + inner_idx);
1601+
1602+
local_C0 += a[0] * b0;
1603+
local_C1 += a[1] * b1;
1604+
local_C2 += a[2] * b2;
1605+
local_C3 += a[3] * b3;
1606+
local_C0 += a[4] * b4;
1607+
local_C1 += a[5] * b5;
1608+
local_C2 += a[6] * b6;
1609+
local_C3 += a[7] * b7;
1610+
local_C0 += a[8] * b8;
1611+
local_C1 += a[9] * b9;
1612+
local_C2 += a[10] * b10;
1613+
local_C3 += a[11] * b11;
1614+
local_C0 += a[12] * b12;
1615+
local_C1 += a[13] * b13;
1616+
local_C2 += a[14] * b14;
1617+
local_C3 += a[15] * b15;
1618+
local_C0 += a[16] * b16;
1619+
local_C1 += a[17] * b17;
1620+
local_C2 += a[18] * b18;
1621+
local_C3 += a[19] * b19;
1622+
local_C0 += a[20] * b20;
1623+
local_C1 += a[21] * b21;
1624+
local_C2 += a[22] * b22;
1625+
local_C3 += a[23] * b23;
1626+
local_C0 += a[24] * b24;
1627+
local_C1 += a[25] * b25;
1628+
local_C2 += a[26] * b26;
1629+
local_C3 += a[27] * b27;
1630+
local_C0 += a[28] * b28;
1631+
local_C1 += a[29] * b29;
1632+
local_C2 += a[30] * b30;
1633+
local_C3 += a[31] * b31;
1634+
}
15981635
} else {
15991636
float b_vals[32] = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15,
16001637
b16, b17, b18, b19, b20, b21, b22, b23, b24, b25, b26, b27, b28, b29, b30, b31};

0 commit comments

Comments (0)