Better handling of the batch dimension in matmul.

This commit is contained in:
laurent
2023-07-03 22:51:40 +01:00
parent ee3f7c0269
commit 86d691c74c
2 changed files with 39 additions and 20 deletions

View File

@ -216,9 +216,6 @@ impl Map2 for MatMul {
let (b, m, n, k) = self.0;
let lhs = &lhs[lhs_l.start_offset()..];
let rhs = &rhs[rhs_l.start_offset()..];
let a_skip: usize = m * k;
let b_skip: usize = n * k;
let c_skip: usize = m * n;
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
@ -229,15 +226,17 @@ impl Map2 for MatMul {
let rhs_cs = rhs_stride[rank - 1];
let rhs_rs = rhs_stride[rank - 2];
if lhs_stride.len() > 2 {
let lhs_batch_stride = &lhs_stride[..rank - 2];
let rhs_batch_stride = &rhs_stride[..rank - 2];
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
// Temporary error before we support abitrary striding.
return Err(Error::UnexpectedStriding);
}
}
let a_skip: usize = match lhs_stride[..rank - 2] {
[stride] => stride,
[] => m * k,
_ => Err(Error::UnexpectedStriding)?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[stride] => stride,
[] => n * k,
_ => Err(Error::UnexpectedStriding)?,
};
let c_skip: usize = m * n;
let dst_shape: Shape = (m, n).into();
let dst_strides = dst_shape.stride_contiguous();
@ -245,16 +244,16 @@ impl Map2 for MatMul {
let dst_cs = dst_strides[1];
let mut dst = vec![T::zero(); b * m * n];
let num_threads = crate::utils::get_num_threads();
let parallelism = if num_threads > 1 {
Parallelism::Rayon(num_threads)
} else {
Parallelism::None
};
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
let num_threads = crate::utils::get_num_threads();
let parallelism = if num_threads > 1 {
Parallelism::Rayon(num_threads)
} else {
Parallelism::None
};
unsafe {
gemm(
/* m: usize = */ m,