mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
More flexible matmul contiguity checks. (#1949)
* More flexible matmul contiguity checks. * Also relax the checks on the metal side.
This commit is contained in:
@ -1651,9 +1651,11 @@ fn gemm_config<T>(
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// The a tensor has dims batching, k, n (rhs)
|
||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||
// there is a single element.
|
||||
let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
|
||||
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
|
||||
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||
} else {
|
||||
Err(CudaError::MatMulNonContiguous {
|
||||
@ -1663,9 +1665,11 @@ fn gemm_config<T>(
|
||||
})?
|
||||
};
|
||||
// The b tensor has dims batching, m, k (lhs)
|
||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||
// there is a single element.
|
||||
let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
|
||||
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
|
||||
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||
} else {
|
||||
Err(CudaError::MatMulNonContiguous {
|
||||
|
@ -2007,6 +2007,16 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a tensor that is in row major order. This always makes a copy.
|
||||
pub fn force_contiguous(&self) -> Result<Tensor> {
|
||||
let shape = self.shape();
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let op = BackpropOp::new1(self, Op::Copy);
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
|
||||
/// Create a variable based on the values currently stored in a tensor. The storage is always
|
||||
/// copied.
|
||||
pub(crate) fn make_var(&self) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user