More flexible matmul contiguity checks. (#1949)

* More flexible matmul contiguity checks.

* Also relax the checks on the metal side.
This commit is contained in:
Laurent Mazare
2024-03-27 10:59:05 +01:00
committed by GitHub
parent 75b6d4b0da
commit a9abde5f93
4 changed files with 51 additions and 8 deletions

View File

@ -1651,9 +1651,11 @@ fn gemm_config<T>(
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The a tensor has dims batching, k, n (rhs)
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
(n as i32, cublasOperation_t::CUBLAS_OP_N)
} else if rhs_m1 == k && rhs_m2 == 1 {
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
(k as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
@ -1663,9 +1665,11 @@ fn gemm_config<T>(
})?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
(k as i32, cublasOperation_t::CUBLAS_OP_N)
} else if lhs_m1 == m && lhs_m2 == 1 {
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
(m as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {

View File

@ -2007,6 +2007,16 @@ impl Tensor {
}
}
/// Returns a tensor that is in row major order. This always makes a copy.
pub fn force_contiguous(&self) -> Result<Tensor> {
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
Ok(from_storage(storage, shape.clone(), op, false))
}
/// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {

View File

@ -1135,6 +1135,30 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
let x = a.i((.., seq_len - 1, ..))?;
println!(
"x shape:{:?}, stride:{:?}, is_contiguous:{}",
x.shape(),
x.stride(),
x.is_contiguous()
);
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
println!(
"w shape:{:?}, stride:{:?}, is_contiguous:{}",
w.shape(),
w.stride(),
w.is_contiguous()
);
let x = x.matmul(&w)?;
assert_eq!(x.dims(), &[1, 32]);
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
@ -1190,6 +1214,7 @@ test_device!(
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381

View File

@ -1451,9 +1451,12 @@ pub fn call_gemm(
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
false
} else if lhs_m1 == m && lhs_m2 == 1 {
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
@ -1462,9 +1465,10 @@ pub fn call_gemm(
mnk: (m, n, k),
})?;
};
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
// rhs has shape b, k, n
let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
false
} else if rhs_m1 == k && rhs_m2 == 1 {
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {