diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index d9141f0f..3dfc5c8f 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -216,9 +216,6 @@ impl Map2 for MatMul {
         let (b, m, n, k) = self.0;
         let lhs = &lhs[lhs_l.start_offset()..];
         let rhs = &rhs[rhs_l.start_offset()..];
-        let a_skip: usize = m * k;
-        let b_skip: usize = n * k;
-        let c_skip: usize = m * n;
 
         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
@@ -229,15 +226,17 @@ impl Map2 for MatMul {
         let rhs_cs = rhs_stride[rank - 1];
         let rhs_rs = rhs_stride[rank - 2];
 
-        if lhs_stride.len() > 2 {
-            let lhs_batch_stride = &lhs_stride[..rank - 2];
-            let rhs_batch_stride = &rhs_stride[..rank - 2];
-
-            if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
-                // Temporary error before we support abitrary striding.
-                return Err(Error::UnexpectedStriding);
-            }
-        }
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(Error::UnexpectedStriding)?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(Error::UnexpectedStriding)?,
+        };
+        let c_skip: usize = m * n;
 
         let dst_shape: Shape = (m, n).into();
         let dst_strides = dst_shape.stride_contiguous();
@@ -245,16 +244,16 @@ impl Map2 for MatMul {
         let dst_cs = dst_strides[1];
 
         let mut dst = vec![T::zero(); b * m * n];
+        let num_threads = crate::utils::get_num_threads();
+        let parallelism = if num_threads > 1 {
+            Parallelism::Rayon(num_threads)
+        } else {
+            Parallelism::None
+        };
         for step in 0..b {
             let lhs_p = &lhs[step * a_skip..];
             let rhs_p = &rhs[step * b_skip..];
             let dst_p = &mut dst[step * c_skip..];
-            let num_threads = crate::utils::get_num_threads();
-            let parallelism = if num_threads > 1 {
-                Parallelism::Rayon(num_threads)
-            } else {
-                Parallelism::None
-            };
             unsafe {
                 gemm(
                     /* m: usize = */ m,
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 641efd7f..0b58787e 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -597,11 +597,31 @@ fn gemm_config(
         transa,
         transb,
     };
+
+    let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
+        [stride] => stride,
+        [] => m * k,
+        _ => Err(CudaError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?,
+    };
+    let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
+        [stride] => stride,
+        [] => n * k,
+        _ => Err(CudaError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?,
+    };
+
     Ok(StridedBatchedConfig {
         batch_size: b as i32,
         gemm,
-        stride_a: (n * k) as i64,
-        stride_b: (m * k) as i64,
+        stride_a: stride_a as i64,
+        stride_b: stride_b as i64,
         stride_c: (m * n) as i64,
     })
 }
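For context, both hunks replace the hard-coded contiguous batch strides (`m * k` / `n * k`) with the actual batch stride read from the layout, which is what allows a broadcasted input (batch stride 0) or an otherwise strided single batch dimension to go through the same batched GEMM path. The snippet below is a minimal standalone sketch of that selection logic, not code from the patch: the `batch_skip` helper name and the example shapes are made up for illustration, and errors are collapsed into `Option` instead of the backend-specific error types used above.

```rust
// Sketch of the batch-stride selection used in both backends (hypothetical
// `batch_skip` helper; `default` stands in for the contiguous m * k / n * k).
fn batch_skip(stride: &[usize], default: usize) -> Option<usize> {
    // The last two stride entries describe the matrix itself; anything
    // before them is the batch dimension (matrices have at least 2 dims).
    match stride[..stride.len() - 2] {
        // One batch dimension: step by its stride. This also covers a
        // broadcasted input, whose batch stride is 0.
        [batch_stride] => Some(batch_stride),
        // No batch dimension: fall back to the contiguous matrix size.
        [] => Some(default),
        // More than one batch dimension is still rejected.
        _ => None,
    }
}

fn main() {
    // Contiguous (2, 3, 4) lhs has strides [12, 4, 1]: skip 12 per batch.
    assert_eq!(batch_skip(&[12, 4, 1], 3 * 4), Some(12));
    // Lhs broadcast over the batch dimension has batch stride 0.
    assert_eq!(batch_skip(&[0, 4, 1], 3 * 4), Some(0));
    // Plain (3, 4) matrix with no batch dimension: use the default m * k.
    assert_eq!(batch_skip(&[4, 1], 3 * 4), Some(3 * 4));
    // Two batch dimensions are not handled by this match.
    assert_eq!(batch_skip(&[24, 12, 4, 1], 3 * 4), None);
}
```

The CPU hunk also hoists the `Parallelism` selection out of the per-batch loop, since the thread count does not change between batches.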