@@ -1553,48 +1553,85 @@ __global__ void kgemm_4bit_inference_naive(
15531553 float b31 = quant_map[qm_offset + (local_B_4bit[15 ] & 0xF )] * local_absmax;
15541554
15551555 if (__builtin_expect (inner_idx + 32 <= K, 1 )) {
1556- int4 a_vec0 = reinterpret_cast <const int4 *>(A_row)[inner_idx / 8 ];
1557- int4 a_vec1 = reinterpret_cast <const int4 *>(A_row)[inner_idx / 8 + 1 ];
1558- int4 a_vec2 = reinterpret_cast <const int4 *>(A_row)[inner_idx / 8 + 2 ];
1559- int4 a_vec3 = reinterpret_cast <const int4 *>(A_row)[inner_idx / 8 + 3 ];
1560-
1561- const T* a0 = reinterpret_cast <const T*>(&a_vec0);
1562- const T* a1 = reinterpret_cast <const T*>(&a_vec1);
1563- const T* a2 = reinterpret_cast <const T*>(&a_vec2);
1564- const T* a3 = reinterpret_cast <const T*>(&a_vec3);
1565-
1566- local_C0 += (float )a0[0 ] * b0;
1567- local_C1 += (float )a0[1 ] * b1;
1568- local_C2 += (float )a0[2 ] * b2;
1569- local_C3 += (float )a0[3 ] * b3;
1570- local_C0 += (float )a0[4 ] * b4;
1571- local_C1 += (float )a0[5 ] * b5;
1572- local_C2 += (float )a0[6 ] * b6;
1573- local_C3 += (float )a0[7 ] * b7;
1574- local_C0 += (float )a1[0 ] * b8;
1575- local_C1 += (float )a1[1 ] * b9;
1576- local_C2 += (float )a1[2 ] * b10;
1577- local_C3 += (float )a1[3 ] * b11;
1578- local_C0 += (float )a1[4 ] * b12;
1579- local_C1 += (float )a1[5 ] * b13;
1580- local_C2 += (float )a1[6 ] * b14;
1581- local_C3 += (float )a1[7 ] * b15;
1582- local_C0 += (float )a2[0 ] * b16;
1583- local_C1 += (float )a2[1 ] * b17;
1584- local_C2 += (float )a2[2 ] * b18;
1585- local_C3 += (float )a2[3 ] * b19;
1586- local_C0 += (float )a2[4 ] * b20;
1587- local_C1 += (float )a2[5 ] * b21;
1588- local_C2 += (float )a2[6 ] * b22;
1589- local_C3 += (float )a2[7 ] * b23;
1590- local_C0 += (float )a3[0 ] * b24;
1591- local_C1 += (float )a3[1 ] * b25;
1592- local_C2 += (float )a3[2 ] * b26;
1593- local_C3 += (float )a3[3 ] * b27;
1594- local_C0 += (float )a3[4 ] * b28;
1595- local_C1 += (float )a3[5 ] * b29;
1596- local_C2 += (float )a3[6 ] * b30;
1597- local_C3 += (float )a3[7 ] * b31;
1556+ if constexpr (BITS == 16 ) {
1557+ int4 a_vec0 = reinterpret_cast <const int4 *>(A_row)[inner_idx / 8 ];
1558+ int4 a_vec1 = reinterpret_cast <const int4 *>(A_row)[inner_idx / 8 + 1 ];
1559+ int4 a_vec2 = reinterpret_cast <const int4 *>(A_row)[inner_idx / 8 + 2 ];
1560+ int4 a_vec3 = reinterpret_cast <const int4 *>(A_row)[inner_idx / 8 + 3 ];
1561+
1562+ const T* a0 = reinterpret_cast <const T*>(&a_vec0);
1563+ const T* a1 = reinterpret_cast <const T*>(&a_vec1);
1564+ const T* a2 = reinterpret_cast <const T*>(&a_vec2);
1565+ const T* a3 = reinterpret_cast <const T*>(&a_vec3);
1566+
1567+ local_C0 += (float )a0[0 ] * b0;
1568+ local_C1 += (float )a0[1 ] * b1;
1569+ local_C2 += (float )a0[2 ] * b2;
1570+ local_C3 += (float )a0[3 ] * b3;
1571+ local_C0 += (float )a0[4 ] * b4;
1572+ local_C1 += (float )a0[5 ] * b5;
1573+ local_C2 += (float )a0[6 ] * b6;
1574+ local_C3 += (float )a0[7 ] * b7;
1575+ local_C0 += (float )a1[0 ] * b8;
1576+ local_C1 += (float )a1[1 ] * b9;
1577+ local_C2 += (float )a1[2 ] * b10;
1578+ local_C3 += (float )a1[3 ] * b11;
1579+ local_C0 += (float )a1[4 ] * b12;
1580+ local_C1 += (float )a1[5 ] * b13;
1581+ local_C2 += (float )a1[6 ] * b14;
1582+ local_C3 += (float )a1[7 ] * b15;
1583+ local_C0 += (float )a2[0 ] * b16;
1584+ local_C1 += (float )a2[1 ] * b17;
1585+ local_C2 += (float )a2[2 ] * b18;
1586+ local_C3 += (float )a2[3 ] * b19;
1587+ local_C0 += (float )a2[4 ] * b20;
1588+ local_C1 += (float )a2[5 ] * b21;
1589+ local_C2 += (float )a2[6 ] * b22;
1590+ local_C3 += (float )a2[7 ] * b23;
1591+ local_C0 += (float )a3[0 ] * b24;
1592+ local_C1 += (float )a3[1 ] * b25;
1593+ local_C2 += (float )a3[2 ] * b26;
1594+ local_C3 += (float )a3[3 ] * b27;
1595+ local_C0 += (float )a3[4 ] * b28;
1596+ local_C1 += (float )a3[5 ] * b29;
1597+ local_C2 += (float )a3[6 ] * b30;
1598+ local_C3 += (float )a3[7 ] * b31;
1599+ } else {
1600+ const float * a = reinterpret_cast <const float *>(A_row + inner_idx);
1601+
1602+ local_C0 += a[0 ] * b0;
1603+ local_C1 += a[1 ] * b1;
1604+ local_C2 += a[2 ] * b2;
1605+ local_C3 += a[3 ] * b3;
1606+ local_C0 += a[4 ] * b4;
1607+ local_C1 += a[5 ] * b5;
1608+ local_C2 += a[6 ] * b6;
1609+ local_C3 += a[7 ] * b7;
1610+ local_C0 += a[8 ] * b8;
1611+ local_C1 += a[9 ] * b9;
1612+ local_C2 += a[10 ] * b10;
1613+ local_C3 += a[11 ] * b11;
1614+ local_C0 += a[12 ] * b12;
1615+ local_C1 += a[13 ] * b13;
1616+ local_C2 += a[14 ] * b14;
1617+ local_C3 += a[15 ] * b15;
1618+ local_C0 += a[16 ] * b16;
1619+ local_C1 += a[17 ] * b17;
1620+ local_C2 += a[18 ] * b18;
1621+ local_C3 += a[19 ] * b19;
1622+ local_C0 += a[20 ] * b20;
1623+ local_C1 += a[21 ] * b21;
1624+ local_C2 += a[22 ] * b22;
1625+ local_C3 += a[23 ] * b23;
1626+ local_C0 += a[24 ] * b24;
1627+ local_C1 += a[25 ] * b25;
1628+ local_C2 += a[26 ] * b26;
1629+ local_C3 += a[27 ] * b27;
1630+ local_C0 += a[28 ] * b28;
1631+ local_C1 += a[29 ] * b29;
1632+ local_C2 += a[30 ] * b30;
1633+ local_C3 += a[31 ] * b31;
1634+ }
15981635 } else {
15991636 float b_vals[32 ] = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15,
16001637 b16, b17, b18, b19, b20, b21, b22, b23, b24, b25, b26, b27, b28, b29, b30, b31};
0 commit comments