Better handling of the batch dimension in matmul.

This commit is contained in:
laurent
2023-07-03 22:51:40 +01:00
parent ee3f7c0269
commit 86d691c74c
2 changed files with 39 additions and 20 deletions

View File

@ -597,11 +597,31 @@ fn gemm_config<T>(
transa,
transb,
};
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
[stride] => stride,
[] => m * k,
_ => Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
};
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
[stride] => stride,
[] => n * k,
_ => Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
};
Ok(StridedBatchedConfig {
batch_size: b as i32,
gemm,
stride_a: (n * k) as i64,
stride_b: (m * k) as i64,
stride_a: stride_a as i64,
stride_b: stride_b as i64,
stride_c: (m * n) as i64,
})
}