mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Properly handle the batch dimension in cuda quantized matmul. (#1832)
@@ -313,7 +313,7 @@ impl QCudaStorage {
             }
         }
         let data_f32 = self.dequantize(n * k)?;
-        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
+        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
         let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
         let mut out_shape = layout.shape().dims().to_vec();
         out_shape.pop();
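Before this change the right-hand-side layout described a single (k, n) matrix, while the left-hand side carries a leading batch dimension, so the batched matmul only consumed one weight slice. Broadcasting the layout to (b, k, n) adds a stride-0 batch dimension, reusing the same dequantized weights for every batch element. A minimal sketch of that layout-level effect, using the public candle_core::Layout API (the diff refers to the same type as crate::Layout from inside the crate); the shapes here are illustrative:

```rust
use candle_core::{Layout, Result};

fn main() -> Result<()> {
    let (b, k, n) = (2, 4, 5);

    // The dequantized weights are contiguous (n, k) data; strides [1, k]
    // expose them as a transposed (k, n) view without copying.
    let rhs_l = Layout::new((k, n).into(), vec![1, k], 0);
    assert_eq!(rhs_l.dims(), &[k, n]);

    // The fix: broadcast the 2D view to (b, k, n). The new leading batch
    // dimension gets stride 0, so every batch element of the matmul reads
    // the same weight matrix instead of only the first batch being handled.
    let rhs_l = rhs_l.broadcast_as((b, k, n))?;
    assert_eq!(rhs_l.dims(), &[b, k, n]);
    assert_eq!(rhs_l.stride(), &[0, 1, k]);
    Ok(())
}
```

At the tensor level this mirrors what `Tensor::broadcast_matmul` does for a (b, m, k) input against a shared (k, n) weight: the weight is expanded across the batch by stride manipulation alone, with no extra allocation.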