mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Better handling of the batch dimension in matmul.
This commit is contained in:
@ -597,11 +597,31 @@ fn gemm_config<T>(
|
||||
transa,
|
||||
transb,
|
||||
};
|
||||
|
||||
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(CudaError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
};
|
||||
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(CudaError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
};
|
||||
|
||||
Ok(StridedBatchedConfig {
|
||||
batch_size: b as i32,
|
||||
gemm,
|
||||
stride_a: (n * k) as i64,
|
||||
stride_b: (m * k) as i64,
|
||||
stride_a: stride_a as i64,
|
||||
stride_b: stride_b as i64,
|
||||
stride_c: (m * n) as i64,
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user