You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
based on @Jamezo97 and @acpopescu work
manually cherry-picked from PR #788 and PR #229 and cleanup by wkpark
Signed-off-by: Won-Kyu Park <wkpark@gmail.com>
Copy file name to clipboardExpand all lines: csrc/kernels.cu
+6-6Lines changed: 6 additions & 6 deletions
Original file line number
Diff line number
Diff line change
@@ -3816,12 +3816,12 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N
3816
3816
template __global__voidkExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
3817
3817
template __global__voidkExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
3818
3818
3819
-
template __global__void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3820
-
template __global__void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3821
-
template __global__void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3822
-
template __global__void kspmm_coo_very_sparse_naive<signedchar, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signedchar *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3823
-
template __global__void kspmm_coo_very_sparse_naive<signedchar, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signedchar *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3824
-
template __global__void kspmm_coo_very_sparse_naive<signedchar, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signedchar *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3819
+
template __global__void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *__restrict__constdequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3820
+
template __global__void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *__restrict__constdequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3821
+
template __global__void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *__restrict__constdequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3822
+
template __global__void kspmm_coo_very_sparse_naive<signedchar, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signedchar *B, half *out, float *__restrict__constdequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3823
+
template __global__void kspmm_coo_very_sparse_naive<signedchar, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signedchar *B, half *out, float *__restrict__constdequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3824
+
template __global__void kspmm_coo_very_sparse_naive<signedchar, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signedchar *B, half *out, float *__restrict__constdequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3825
3825
3826
3826
template __global__voidkTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
3827
3827
template __global__voidkTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
0 commit comments