Better handling of the batch dimension in matmul.

laurent
2023-07-03 22:51:40 +01:00
parent ee3f7c0269
commit 86d691c74c
2 changed files with 39 additions and 20 deletions

View File

@@ -216,9 +216,6 @@ impl Map2 for MatMul {
let (b, m, n, k) = self.0;
let lhs = &lhs[lhs_l.start_offset()..];
let rhs = &rhs[rhs_l.start_offset()..];
let a_skip: usize = m * k;
let b_skip: usize = n * k;
let c_skip: usize = m * n;
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
@@ -229,15 +226,17 @@ impl Map2 for MatMul {
let rhs_cs = rhs_stride[rank - 1];
let rhs_rs = rhs_stride[rank - 2];
if lhs_stride.len() > 2 {
    let lhs_batch_stride = &lhs_stride[..rank - 2];
    let rhs_batch_stride = &rhs_stride[..rank - 2];
    if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
        // Temporary error before we support arbitrary striding.
        return Err(Error::UnexpectedStriding);
    }
}
let a_skip: usize = match lhs_stride[..rank - 2] {
    [stride] => stride,
    [] => m * k,
    _ => Err(Error::UnexpectedStriding)?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
    [stride] => stride,
    [] => n * k,
    _ => Err(Error::UnexpectedStriding)?,
};
let c_skip: usize = m * n;
let dst_shape: Shape = (m, n).into();
let dst_strides = dst_shape.stride_contiguous();
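
The rule the new match expressions encode can be exercised on its own. The helper below is a sketch, not code from this commit (UnexpectedStriding here is a local stand-in for the crate's Error::UnexpectedStriding): a single batch dimension contributes its own stride, no batch dimension falls back to the old contiguous default, and anything else is still rejected.

// Sketch only: standalone version of the batch-skip selection added above.
#[derive(Debug)]
struct UnexpectedStriding; // stand-in for the crate's error variant

fn batch_skip(
    batch_stride: &[usize],
    contiguous_default: usize,
) -> Result<usize, UnexpectedStriding> {
    match batch_stride {
        // One batch dimension: step from one matrix to the next using its stride.
        [stride] => Ok(*stride),
        // No batch dimension: same value as the previous hard-coded m * k / n * k.
        [] => Ok(contiguous_default),
        // More than one batch dimension is still unsupported.
        _ => Err(UnexpectedStriding),
    }
}

fn main() {
    // Contiguous lhs of shape (2, 3, 4): strides are [12, 4, 1], so the batch part is [12].
    let lhs_stride = [12usize, 4, 1];
    let (m, k) = (3usize, 4usize);
    let a_skip = batch_skip(&lhs_stride[..lhs_stride.len() - 2], m * k).unwrap();
    assert_eq!(a_skip, 12); // identical to the old m * k default for this layout
}
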
@@ -245,16 +244,16 @@ impl Map2 for MatMul {
let dst_cs = dst_strides[1];
let mut dst = vec![T::zero(); b * m * n];
let num_threads = crate::utils::get_num_threads();
let parallelism = if num_threads > 1 {
    Parallelism::Rayon(num_threads)
} else {
    Parallelism::None
};
for step in 0..b {
    let lhs_p = &lhs[step * a_skip..];
    let rhs_p = &rhs[step * b_skip..];
    let dst_p = &mut dst[step * c_skip..];
    let num_threads = crate::utils::get_num_threads();
    let parallelism = if num_threads > 1 {
        Parallelism::Rayon(num_threads)
    } else {
        Parallelism::None
    };
    unsafe {
        gemm(
            /* m: usize = */ m,
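
The batch loop above now only offsets into the flat lhs, rhs and dst buffers by the chosen skip values and calls gemm once per batch element. A rough, self-contained sketch of that pattern (batched_matmul_naive is illustrative, not part of the codebase; a naive triple loop stands in for the real gemm call, and the inputs are assumed row-major contiguous):

// Sketch only: per-batch offsets into flat buffers, mirroring the loop above.
fn batched_matmul_naive(
    lhs: &[f32],
    rhs: &[f32],
    (b, m, n, k): (usize, usize, usize, usize),
    (a_skip, b_skip, c_skip): (usize, usize, usize),
) -> Vec<f32> {
    let mut dst = vec![0f32; b * m * n];
    for step in 0..b {
        let lhs_p = &lhs[step * a_skip..];
        let rhs_p = &rhs[step * b_skip..];
        let dst_p = &mut dst[step * c_skip..];
        // Naive row-major matmul in place of the real gemm call.
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0f32;
                for l in 0..k {
                    acc += lhs_p[i * k + l] * rhs_p[l * n + j];
                }
                dst_p[i * n + j] = acc;
            }
        }
    }
    dst
}

fn main() {
    // Two batches of 1x2 @ 2x1 products: each output is a single dot product.
    let lhs = [1f32, 2., 3., 4.]; // shape (2, 1, 2), a_skip = 2
    let rhs = [10f32, 20., 30., 40.]; // shape (2, 2, 1), b_skip = 2
    let dst = batched_matmul_naive(&lhs, &rhs, (2, 1, 1, 2), (2, 2, 1));
    assert_eq!(dst, vec![50f32, 250f32]); // 1*10 + 2*20 and 3*30 + 4*40
}
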

View File

@@ -597,11 +597,31 @@ fn gemm_config<T>(
    transa,
    transb,
};
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
    [stride] => stride,
    [] => m * k,
    _ => Err(CudaError::MatMulNonContiguous {
        lhs_stride: lhs_stride.to_vec(),
        rhs_stride: rhs_stride.to_vec(),
        mnk: (m, n, k),
    })?,
};
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
    [stride] => stride,
    [] => n * k,
    _ => Err(CudaError::MatMulNonContiguous {
        lhs_stride: lhs_stride.to_vec(),
        rhs_stride: rhs_stride.to_vec(),
        mnk: (m, n, k),
    })?,
};
Ok(StridedBatchedConfig {
    batch_size: b as i32,
    gemm,
    stride_a: (n * k) as i64,
    stride_b: (m * k) as i64,
    stride_a: stride_a as i64,
    stride_b: stride_b as i64,
    stride_c: (m * n) as i64,
})
}
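
On the CUDA side the same slice-pattern rule now feeds the strided-batched GEMM configuration; as the hunk shows, stride_b is derived from lhs_stride and stride_a from rhs_stride, replacing the hard-coded (m * k) and (n * k). A quick sanity check (a sketch, not code from the commit; panics stand in for CudaError::MatMulNonContiguous) confirms that contiguous inputs reproduce the old constants:

// Sketch only: contiguous (b, m, k) and (b, k, n) inputs yield the old m * k and n * k strides.
fn main() {
    let (m, n, k) = (3usize, 5usize, 4usize);
    // Row-major contiguous strides for shapes (b, m, k) and (b, k, n).
    let lhs_stride = [m * k, k, 1];
    let rhs_stride = [k * n, n, 1];
    let stride_b = match &lhs_stride[..lhs_stride.len() - 2] {
        [stride] => *stride,
        [] => m * k,
        _ => panic!("more than one batch dimension"), // stand-in for the CUDA error path
    };
    let stride_a = match &rhs_stride[..rhs_stride.len() - 2] {
        [stride] => *stride,
        [] => n * k,
        _ => panic!("more than one batch dimension"),
    };
    assert_eq!((stride_b, stride_a), (m * k, n * k));
}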