Add a batch dimension on the bert example.

This commit is contained in:
laurent
2023-07-04 06:10:52 +01:00
parent 8e4d298c90
commit a57b314780
4 changed files with 32 additions and 15 deletions

View File

@ -227,14 +227,22 @@ impl Map2 for MatMul {
let rhs_rs = rhs_stride[rank - 2];
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(Error::UnexpectedStriding)?,
_ => Err(Error::UnexpectedStriding {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
})?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(Error::UnexpectedStriding)?,
_ => Err(Error::UnexpectedStriding {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
})?,
};
let c_skip: usize = m * n;