Mirror of https://github.com/huggingface/candle.git (synced 2025-06-20 20:09:50 +00:00)
More flexible matmul contiguity checks. (#1949)
* More flexible matmul contiguity checks.
* Also relax the checks on the metal side.
@@ -1651,9 +1651,11 @@ fn gemm_config<T>(
     let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
     let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
     // The a tensor has dims batching, k, n (rhs)
-    let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
         (n as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if rhs_m1 == k && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
         (k as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
@@ -1663,9 +1665,11 @@ fn gemm_config<T>(
         })?
     };
     // The b tensor has dims batching, m, k (lhs)
-    let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
         (k as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if lhs_m1 == m && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
         (m as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
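The relaxation in gemm_config only kicks in when the stride that looks wrong can never actually be used: cuBLAS only walks the leading dimension of an operand when there is more than one row (or column) to step over, so with b * k == 1 (respectively b * n == 1) the rhs stride on that axis is irrelevant and the operand can still be handed to cuBLAS as plain row- or column-major. A minimal standalone sketch of the relaxed rhs check, using a hypothetical helper name that is not part of candle's API:

/// Hypothetical sketch of the relaxed layout check for the rhs operand of a
/// batched matmul with dims (b, k, n). `rhs_m2` / `rhs_m1` are the strides of
/// the k and n dimensions; returns Some((leading_dim, transpose)) when the
/// layout is usable, None otherwise.
fn rhs_layout(rhs_m2: usize, rhs_m1: usize, b: usize, k: usize, n: usize) -> Option<(usize, bool)> {
    if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
        // Row-major (k, n): the k-stride only matters when there is more than
        // one row across all batches, hence the `b * k == 1` escape hatch.
        Some((n, false))
    } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
        // Column-major (k, n), i.e. a transposed row-major (n, k).
        Some((k, true))
    } else {
        None
    }
}

fn main() {
    // A (1, 1, 16) rhs sliced out of a larger tensor may keep a stride of 128
    // on its k dimension; with b * k == 1 that stride is never dereferenced.
    assert_eq!(rhs_layout(128, 1, 1, 1, 16), Some((16, false)));
    // The usual contiguous case still maps to the non-transposed path.
    assert_eq!(rhs_layout(16, 1, 4, 8, 16), Some((16, false)));
}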
@@ -2007,6 +2007,16 @@ impl Tensor {
         }
     }
+
+    /// Returns a tensor that is in row major order. This always makes a copy.
+    pub fn force_contiguous(&self) -> Result<Tensor> {
+        let shape = self.shape();
+        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let op = BackpropOp::new1(self, Op::Copy);
+        Ok(from_storage(storage, shape.clone(), op, false))
+    }
 
     /// Create a variable based on the values currently stored in a tensor. The storage is always
     /// copied.
     pub(crate) fn make_var(&self) -> Result<Tensor> {
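A quick usage sketch for the new method (assuming the candle_core crate; only force_contiguous itself comes from this change). Unlike contiguous(), which returns the tensor unchanged when it is already row-major, force_contiguous always allocates fresh storage and copies into it:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    // A transposed view keeps the original storage and is not contiguous.
    let xs = Tensor::arange(0f32, 6f32, &dev)?.reshape((2, 3))?.t()?;
    assert!(!xs.is_contiguous());
    // The forced copy is row-major again, backed by newly allocated storage.
    let ys = xs.force_contiguous()?;
    assert!(ys.is_contiguous());
    assert_eq!(ys.dims(), &[3, 2]);
    Ok(())
}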
@@ -1135,6 +1135,30 @@ fn randn(device: &Device) -> Result<()> {
     Ok(())
 }
+
+// https://github.com/huggingface/candle/issues/1948
+fn squeeze_mm(device: &Device) -> Result<()> {
+    let seq_len = 8_usize;
+    let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
+    let x = a.i((.., seq_len - 1, ..))?;
+    println!(
+        "x shape:{:?}, stride:{:?}, is_contiguous:{}",
+        x.shape(),
+        x.stride(),
+        x.is_contiguous()
+    );
+
+    let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
+    println!(
+        "w shape:{:?}, stride:{:?}, is_contiguous:{}",
+        w.shape(),
+        w.stride(),
+        w.is_contiguous()
+    );
+    let x = x.matmul(&w)?;
+    assert_eq!(x.dims(), &[1, 32]);
+    Ok(())
+}
 
 test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
 test_device!(ones, ones_cpu, ones_gpu, ones_metal);
 test_device!(full, full_cpu, full_gpu, full_metal);
@@ -1190,6 +1214,7 @@ test_device!(
 test_device!(randn, randn_cpu, randn_gpu, randn_metal);
 test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
+test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
 
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
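For reference, the failing case from issue #1948 boils down to stride arithmetic: `a` has shape (1, 8, 16) with row-major strides (128, 16, 1), and indexing the middle dimension yields a (1, 16) view with strides (128, 1). In the matmul against `w` this gives b = 1, m = 1, k = 16, so the old lhs check `lhs_m2 == k` compares 128 with 16 and rejects a layout that is perfectly usable because there is only a single row. A small self-contained sketch of that arithmetic (plain Rust, no candle dependency):

// Row-major strides for a shape: stride[i] = product of the dims to its right.
fn row_major_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn main() {
    // Tensor of shape (1, 8, 16) as in the squeeze_mm test.
    let strides = row_major_strides(&[1, 8, 16]);
    assert_eq!(strides, vec![128, 16, 1]);

    // Selecting an index on dim 1 keeps the strides of the remaining dims:
    // a (1, 16) view with strides (128, 1).
    let (lhs_m2, lhs_m1) = (strides[0], strides[2]);
    let (b, m, k) = (1usize, 1usize, 16usize);

    // Old check: requires lhs_m2 == k, but 128 != 16.
    assert!(!(lhs_m1 == 1 && lhs_m2 == k));
    // Relaxed check: with b * m == 1 the m-stride is never dereferenced.
    assert!(lhs_m1 == 1 && (lhs_m2 == k || b * m == 1));
}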
@@ -1451,9 +1451,12 @@ pub fn call_gemm(
     let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
     let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
     let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
-    let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
+    // lhs has shape b, m, k
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
         false
-    } else if lhs_m1 == m && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
         true
     } else {
         return Err(MetalKernelError::MatMulNonContiguous {
@@ -1462,9 +1465,10 @@ pub fn call_gemm(
             mnk: (m, n, k),
         })?;
     };
-    let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
+    // rhs has shape b, k, n
+    let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
         false
-    } else if rhs_m1 == k && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
         true
     } else {
         return Err(MetalKernelError::MatMulNonContiguous {
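With this, call_gemm on the Metal side and gemm_config on the CUDA side apply the same decision rule: treat an operand as row-major when its innermost stride is 1 and either the next stride matches the expected width or the operand effectively has a single row, and as transposed in the symmetric case. Distilled into one hypothetical helper (a sketch only, neither backend is structured this way):

/// Sketch of the shared stride rule. `m2` / `m1` are the operand's two
/// innermost strides, `rows` / `cols` its logical matrix dimensions
/// (m x k for the lhs, k x n for the rhs), and `batch` the batch size.
/// Returns Some(transpose_flag) when the layout is usable.
fn pick_transpose(m2: usize, m1: usize, batch: usize, rows: usize, cols: usize) -> Option<bool> {
    if m1 == 1 && (m2 == cols || batch * rows == 1) {
        // Row-major enough: the row stride is either correct or never used.
        Some(false)
    } else if (m1 == rows || batch * cols == 1) && m2 == 1 {
        // Column-major enough: let the kernel transpose the operand.
        Some(true)
    } else {
        // Anything else still surfaces as a MatMulNonContiguous error.
        None
    }
}

fn main() {
    // The squeeze_mm lhs: strides (128, 1) for an effective 1 x 16 matrix.
    assert_eq!(pick_transpose(128, 1, 1, 1, 16), Some(false));
    // A transposed (32, 16) weight used as a 16 x 32 rhs: strides (1, 16).
    assert_eq!(pick_transpose(1, 16, 1, 16, 32), Some(true));
}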