mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add a batch dimension on the bert example.
This commit is contained in:
@ -227,14 +227,22 @@ impl Map2 for MatMul {
|
||||
let rhs_rs = rhs_stride[rank - 2];
|
||||
|
||||
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(Error::UnexpectedStriding)?,
|
||||
_ => Err(Error::UnexpectedStriding {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
})?,
|
||||
};
|
||||
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(Error::UnexpectedStriding)?,
|
||||
_ => Err(Error::UnexpectedStriding {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
})?,
|
||||
};
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
|
Reference in New Issue
Block a user