@@ -210,37 +210,28 @@ def forward(
210210 # 2. Quantize B
211211 state .CB , state .SCB , _ = F .int8_vectorwise_quant (B .to (torch .float16 ))
212212
213- # Handle sparse decomposition. In some instances, we may have not found any
214- # outlier columns at all. In that case, we'll skip this part completely.
215- if state .threshold > 0.0 and outlier_cols is not None and outlier_cols .numel ():
213+ # Handle sparse decomposition
214+ if state .threshold > 0.0 :
216215 state .idx = outlier_cols
217216
218- # Zero out the outliers in the transposed 8bit inputs.
219- if CAt is not None :
220- CAt [:, state .idx ] = 0
221-
222- # Extract the input outliers in original precision
223- subA = A [:, state .idx ].contiguous ()
217+ # Mixed Int8 Matmul + Dequant + Bias
218+ output , subA = torch .ops .bitsandbytes .int8_mixed_scaled_mm (
219+ A ,
220+ CA ,
221+ state .CB ,
222+ SCA ,
223+ state .SCB ,
224+ outlier_cols ,
225+ bias ,
226+ )
224227
225- # Extract the corresponding weights
226- if state .has_fp16_weights :
227- state .subB = B [:, state .idx ].t ()
228- else :
229- # To dequantize our weights associated with the input outliers,
230- # we want to divide by 127. It's however more performant to multiply
231- # by the reciprocal.
232- outliers = state .CB [:, state .idx ]
233- state .subB = F .int8_vectorwise_dequant (outliers , state .SCB ).to (A .dtype ).t ()
234228 else :
229+ # Int8 Matmul + Dequant + Bias
230+ output = torch .ops .bitsandbytes .int8_scaled_mm .default (
231+ CA , state .CB , SCA , state .SCB , bias = bias , dtype = A .dtype
232+ )
235233 subA = None
236234
237- # 3. Int8 Matmul + Dequant + Bias
238- output = torch .ops .bitsandbytes .int8_scaled_mm .default (CA , state .CB , SCA , state .SCB , bias = bias , dtype = A .dtype )
239-
240- # 4. Mixed-precision decomposition matmul
241- if subA is not None and state .subB is not None :
242- output = output .addmm (subA , state .subB )
243-
244235 # 5. Save state
245236 ctx .state = state
246237
0 commit comments