Add a batch dimension on the bert example.

This commit is contained in:
laurent
2023-07-04 06:10:52 +01:00
parent 8e4d298c90
commit a57b314780
4 changed files with 32 additions and 15 deletions

View File

@ -93,8 +93,11 @@ pub enum Error {
},
// TODO this is temporary when we support arbitrary matmul
#[error("temporary error where matmul doesn't support arbitrary striding")]
UnexpectedStriding,
#[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
UnexpectedStriding {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
},
#[error(transparent)]
Cuda(#[from] crate::CudaError),