mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add a batch dimension on the bert example.
This commit is contained in:
@ -227,14 +227,22 @@ impl Map2 for MatMul {
|
||||
let rhs_rs = rhs_stride[rank - 2];
|
||||
|
||||
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(Error::UnexpectedStriding)?,
|
||||
_ => Err(Error::UnexpectedStriding {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
})?,
|
||||
};
|
||||
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(Error::UnexpectedStriding)?,
|
||||
_ => Err(Error::UnexpectedStriding {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
})?,
|
||||
};
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
|
@ -599,6 +599,7 @@ fn gemm_config<T>(
|
||||
};
|
||||
|
||||
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(CudaError::MatMulNonContiguous {
|
||||
@ -608,6 +609,7 @@ fn gemm_config<T>(
|
||||
})?,
|
||||
};
|
||||
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(CudaError::MatMulNonContiguous {
|
||||
|
@ -93,8 +93,11 @@ pub enum Error {
|
||||
},
|
||||
|
||||
// TODO this is temporary when we support arbitrary matmul
|
||||
#[error("temporary error where matmul doesn't support arbitrary striding")]
|
||||
UnexpectedStriding,
|
||||
#[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
|
||||
UnexpectedStriding {
|
||||
lhs_stride: Vec<usize>,
|
||||
rhs_stride: Vec<usize>,
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Cuda(#[from] crate::CudaError),
|
||||
|
Reference in New Issue
Block a user