mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Properly handle the batch dimension in cuda quantized matmul. (#1832)
This commit is contained in:
@ -313,7 +313,7 @@ impl QCudaStorage {
|
||||
}
|
||||
|
||||
let data_f32 = self.dequantize(n * k)?;
|
||||
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
|
||||
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
|
||||
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
|
||||
let mut out_shape = layout.shape().dims().to_vec();
|
||||
out_shape.pop();
|
||||
|
Reference in New Issue
Block a user