Properly handle the batch dimension in cuda quantized matmul. (#1832)

Laurent Mazare
2024-03-10 20:23:43 +01:00
committed by GitHub
parent 0c5eecbc0f
commit df5f69444e

@@ -313,7 +313,7 @@ impl QCudaStorage {
         }
         let data_f32 = self.dequantize(n * k)?;
-        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
+        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
         let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
         let mut out_shape = layout.shape().dims().to_vec();
         out_shape.pop();
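
For context, here is a minimal sketch of the layout arithmetic the fix relies on. It is not candle's actual CUDA kernel; the batched_matmul helper and its flat-slice signature are hypothetical. The dequantized weight is a single (k, n) matrix stored with strides [1, k], and broadcasting it to (b, k, n) gives it a zero batch stride, so every batch of the (b, m, k) lhs reuses the same weight data instead of the matmul advancing past the end of the buffer.

// Hypothetical illustration, not candle internals: a batched matmul where
// the rhs batch stride is zero, mirroring broadcast_as((b, k, n)) on a
// single (k, n) weight matrix.
fn batched_matmul(
    lhs: &[f32],                             // shape (b, m, k), contiguous
    rhs: &[f32],                             // shape (k, n), shared by all batches
    (b, m, n, k): (usize, usize, usize, usize),
) -> Vec<f32> {
    let mut out = vec![0f32; b * m * n];
    for bi in 0..b {
        // lhs and out advance per batch; rhs does not (zero batch stride).
        let lhs_b = &lhs[bi * m * k..][..m * k];
        let out_b = &mut out[bi * m * n..][..m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0f32;
                for p in 0..k {
                    // Strides [1, k] for shape (k, n) put element (p, j)
                    // at offset p + j * k, matching the diff's rhs layout.
                    acc += lhs_b[i * k + p] * rhs[j * k + p];
                }
                out_b[i * n + j] = acc;
            }
        }
    }
    out
}

fn main() {
    // Two batches share one 3x4 weight; the output has shape (2, 2, 4).
    let (b, m, n, k) = (2usize, 2usize, 4usize, 3usize);
    let lhs: Vec<f32> = (0..b * m * k).map(|v| v as f32).collect();
    let rhs: Vec<f32> = (0..n * k).map(|v| v as f32).collect();
    let out = batched_matmul(&lhs, &rhs, (b, m, n, k));
    assert_eq!(out.len(), b * m * n);
    println!("{out:?}");
}

Without the broadcast, the rhs layout described only a single (k, n) view, so a batched matmul against a (b, m, k) lhs had no valid batch dimension to iterate over; broadcasting the layout is cheaper than materializing b copies of the dequantized weight.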