Relax the matmul stride checks for single-element dimensions.

The CUDA (gemm_config) and Metal (call_gemm) layout checks previously required
both trailing strides of each operand to have their canonical value and
otherwise failed with MatMulNonContiguous. They now also accept an unexpected
stride when the matching `b * dim == 1` test in the condition holds, i.e. when
there is a single element along that stride.

Also adds Tensor::force_contiguous, which always copies into row major order,
and a squeeze_mm regression test (candle issue 1948) registered for the cpu,
gpu and metal devices through test_device!.

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f0f03053..97dc346e 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1651,9 +1651,11 @@ fn gemm_config(
     let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
     let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
     // The a tensor has dims batching, k, n (rhs)
-    let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
         (n as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if rhs_m1 == k && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
         (k as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
@@ -1663,9 +1665,11 @@ fn gemm_config(
         })?
     };
     // The b tensor has dims batching, m, k (lhs)
-    let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
         (k as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if lhs_m1 == m && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
         (m as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 92c931eb..b53b0419 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2007,6 +2007,16 @@ impl Tensor {
         }
     }
 
+    /// Returns a tensor that is in row major order. This always makes a copy.
+    pub fn force_contiguous(&self) -> Result<Tensor> {
+        let shape = self.shape();
+        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let op = BackpropOp::new1(self, Op::Copy);
+        Ok(from_storage(storage, shape.clone(), op, false))
+    }
+
     /// Create a variable based on the values currently stored in a tensor. The storage is always
     /// copied.
     pub(crate) fn make_var(&self) -> Result<Tensor> {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index b2475adc..af28c1c1 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1135,6 +1135,30 @@ fn randn(device: &Device) -> Result<()> {
     Ok(())
 }
 
+// https://github.com/huggingface/candle/issues/1948
+fn squeeze_mm(device: &Device) -> Result<()> {
+    let seq_len = 8_usize;
+    let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
+    let x = a.i((.., seq_len - 1, ..))?;
+    println!(
+        "x shape:{:?}, stride:{:?}, is_contiguous:{}",
+        x.shape(),
+        x.stride(),
+        x.is_contiguous()
+    );
+
+    let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
+    println!(
+        "w shape:{:?}, stride:{:?}, is_contiguous:{}",
+        w.shape(),
+        w.stride(),
+        w.is_contiguous()
+    );
+    let x = x.matmul(&w)?;
+    assert_eq!(x.dims(), &[1, 32]);
+    Ok(())
+}
+
 test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
 test_device!(ones, ones_cpu, ones_gpu, ones_metal);
 test_device!(full, full_cpu, full_gpu, full_metal);
@@ -1190,6 +1214,7 @@ test_device!(
 test_device!(randn, randn_cpu, randn_gpu, randn_metal);
 test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
+test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
 
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 449bef8f..3f452331 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1451,9 +1451,12 @@ pub fn call_gemm(
     let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
     let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
     let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
-    let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
+    // lhs has shape b, m, k
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
         false
-    } else if lhs_m1 == m && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
         true
     } else {
         return Err(MetalKernelError::MatMulNonContiguous {
@@ -1462,9 +1465,10 @@ pub fn call_gemm(
             mnk: (m, n, k),
         })?;
     };
-    let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
+    // rhs has shape b, k, n
+    let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
         false
-    } else if rhs_m1 == k && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
         true
     } else {
         return Err(MetalKernelError::MatMulNonContiguous {