mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Better handling of the batch dimension in matmul.
This commit is contained in:
@ -216,9 +216,6 @@ impl Map2 for MatMul {
|
||||
let (b, m, n, k) = self.0;
|
||||
let lhs = &lhs[lhs_l.start_offset()..];
|
||||
let rhs = &rhs[rhs_l.start_offset()..];
|
||||
let a_skip: usize = m * k;
|
||||
let b_skip: usize = n * k;
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
let lhs_stride = lhs_l.stride();
|
||||
let rhs_stride = rhs_l.stride();
|
||||
@ -229,15 +226,17 @@ impl Map2 for MatMul {
|
||||
let rhs_cs = rhs_stride[rank - 1];
|
||||
let rhs_rs = rhs_stride[rank - 2];
|
||||
|
||||
if lhs_stride.len() > 2 {
|
||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
||||
|
||||
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
||||
// Temporary error before we support arbitrary striding.
|
||||
return Err(Error::UnexpectedStriding);
|
||||
}
|
||||
}
|
||||
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(Error::UnexpectedStriding)?,
|
||||
};
|
||||
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(Error::UnexpectedStriding)?,
|
||||
};
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
let dst_shape: Shape = (m, n).into();
|
||||
let dst_strides = dst_shape.stride_contiguous();
|
||||
@ -245,16 +244,16 @@ impl Map2 for MatMul {
|
||||
let dst_cs = dst_strides[1];
|
||||
|
||||
let mut dst = vec![T::zero(); b * m * n];
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
let parallelism = if num_threads > 1 {
|
||||
Parallelism::Rayon(num_threads)
|
||||
} else {
|
||||
Parallelism::None
|
||||
};
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
let parallelism = if num_threads > 1 {
|
||||
Parallelism::Rayon(num_threads)
|
||||
} else {
|
||||
Parallelism::None
|
||||
};
|
||||
unsafe {
|
||||
gemm(
|
||||
/* m: usize = */ m,
|
||||
|
@ -597,11 +597,31 @@ fn gemm_config<T>(
|
||||
transa,
|
||||
transb,
|
||||
};
|
||||
|
||||
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(CudaError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
};
|
||||
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(CudaError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
};
|
||||
|
||||
Ok(StridedBatchedConfig {
|
||||
batch_size: b as i32,
|
||||
gemm,
|
||||
stride_a: (n * k) as i64,
|
||||
stride_b: (m * k) as i64,
|
||||
stride_a: stride_a as i64,
|
||||
stride_b: stride_b as i64,
|
||||
stride_c: (m * n) as i64,
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user