mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Better handling of the batch dimension in matmul.
This commit is contained in:
@ -216,9 +216,6 @@ impl Map2 for MatMul {
|
|||||||
let (b, m, n, k) = self.0;
|
let (b, m, n, k) = self.0;
|
||||||
let lhs = &lhs[lhs_l.start_offset()..];
|
let lhs = &lhs[lhs_l.start_offset()..];
|
||||||
let rhs = &rhs[rhs_l.start_offset()..];
|
let rhs = &rhs[rhs_l.start_offset()..];
|
||||||
let a_skip: usize = m * k;
|
|
||||||
let b_skip: usize = n * k;
|
|
||||||
let c_skip: usize = m * n;
|
|
||||||
|
|
||||||
let lhs_stride = lhs_l.stride();
|
let lhs_stride = lhs_l.stride();
|
||||||
let rhs_stride = rhs_l.stride();
|
let rhs_stride = rhs_l.stride();
|
||||||
@ -229,15 +226,17 @@ impl Map2 for MatMul {
|
|||||||
let rhs_cs = rhs_stride[rank - 1];
|
let rhs_cs = rhs_stride[rank - 1];
|
||||||
let rhs_rs = rhs_stride[rank - 2];
|
let rhs_rs = rhs_stride[rank - 2];
|
||||||
|
|
||||||
if lhs_stride.len() > 2 {
|
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
[stride] => stride,
|
||||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
[] => m * k,
|
||||||
|
_ => Err(Error::UnexpectedStriding)?,
|
||||||
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
};
|
||||||
// Temporary error before we support abitrary striding.
|
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||||
return Err(Error::UnexpectedStriding);
|
[stride] => stride,
|
||||||
}
|
[] => n * k,
|
||||||
}
|
_ => Err(Error::UnexpectedStriding)?,
|
||||||
|
};
|
||||||
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
let dst_shape: Shape = (m, n).into();
|
let dst_shape: Shape = (m, n).into();
|
||||||
let dst_strides = dst_shape.stride_contiguous();
|
let dst_strides = dst_shape.stride_contiguous();
|
||||||
@ -245,16 +244,16 @@ impl Map2 for MatMul {
|
|||||||
let dst_cs = dst_strides[1];
|
let dst_cs = dst_strides[1];
|
||||||
|
|
||||||
let mut dst = vec![T::zero(); b * m * n];
|
let mut dst = vec![T::zero(); b * m * n];
|
||||||
for step in 0..b {
|
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
|
||||||
let num_threads = crate::utils::get_num_threads();
|
let num_threads = crate::utils::get_num_threads();
|
||||||
let parallelism = if num_threads > 1 {
|
let parallelism = if num_threads > 1 {
|
||||||
Parallelism::Rayon(num_threads)
|
Parallelism::Rayon(num_threads)
|
||||||
} else {
|
} else {
|
||||||
Parallelism::None
|
Parallelism::None
|
||||||
};
|
};
|
||||||
|
for step in 0..b {
|
||||||
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
unsafe {
|
unsafe {
|
||||||
gemm(
|
gemm(
|
||||||
/* m: usize = */ m,
|
/* m: usize = */ m,
|
||||||
|
@ -597,11 +597,31 @@ fn gemm_config<T>(
|
|||||||
transa,
|
transa,
|
||||||
transb,
|
transb,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
|
||||||
|
[stride] => stride,
|
||||||
|
[] => m * k,
|
||||||
|
_ => Err(CudaError::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
||||||
|
[stride] => stride,
|
||||||
|
[] => n * k,
|
||||||
|
_ => Err(CudaError::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
|
||||||
Ok(StridedBatchedConfig {
|
Ok(StridedBatchedConfig {
|
||||||
batch_size: b as i32,
|
batch_size: b as i32,
|
||||||
gemm,
|
gemm,
|
||||||
stride_a: (n * k) as i64,
|
stride_a: stride_a as i64,
|
||||||
stride_b: (m * k) as i64,
|
stride_b: stride_b as i64,
|
||||||
stride_c: (m * n) as i64,
|
stride_c: (m * n) as i64,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user