mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add a batch dimension on the bert example.
This commit is contained in:
@ -599,6 +599,7 @@ fn gemm_config<T>(
|
||||
};
|
||||
|
||||
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(CudaError::MatMulNonContiguous {
|
||||
@ -608,6 +609,7 @@ fn gemm_config<T>(
|
||||
})?,
|
||||
};
|
||||
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(CudaError::MatMulNonContiguous {
|
||||
|
Reference in New Issue
Block a user