Add a batch dimension on the bert example.

This commit is contained in:
laurent
2023-07-04 06:10:52 +01:00
parent 8e4d298c90
commit a57b314780
4 changed files with 32 additions and 15 deletions

View File

@ -599,6 +599,7 @@ fn gemm_config<T>(
};
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(CudaError::MatMulNonContiguous {
@ -608,6 +609,7 @@ fn gemm_config<T>(
})?,
};
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(CudaError::MatMulNonContiguous {