diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 7ccadb44..6d1cea3b 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -360,7 +360,6 @@ impl Map2 for MatMul {
         let (b, m, n, k) = self.0;
         let lhs = &lhs[lhs_l.start_offset()..];
         let rhs = &rhs[rhs_l.start_offset()..];
-        let cfg = crate::cuda_backend::gemm_config(1f32, 0f32, (b, m, n, k), lhs_l, rhs_l)?;
 
         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
@@ -386,56 +385,53 @@ impl Map2 for MatMul {
         };
         let c_skip: usize = m * n;
 
+        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
+        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+            (n as i32, b'N')
+        } else if rhs_m1 == k && rhs_m2 == 1 {
+            (k as i32, b'T')
+        } else {
+            Err(Error::MatMulNonContiguous {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+                mnk: (m, n, k),
+            })?
+        };
+        // The b tensor has dims batching, m, k (lhs)
+        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+            (k as i32, b'N')
+        } else if lhs_m1 == m && lhs_m2 == 1 {
+            (m as i32, b'T')
+        } else {
+            Err(Error::MatMulNonContiguous {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+                mnk: (m, n, k),
+            })?
+        };
+
         let mut dst = vec![T::zero(); b * m * n];
         for step in 0..b {
             let lhs_p = &lhs[step * a_skip..];
             let rhs_p = &rhs[step * b_skip..];
             let dst_p = &mut dst[step * c_skip..];
             unsafe {
-                let gemm = cfg.gemm;
                 let a = rhs_p.as_ptr() as *const f32;
                 let b = lhs_p.as_ptr() as *const f32;
                 let c = dst_p.as_mut_ptr() as *mut f32;
                 let a = std::slice::from_raw_parts(a, a_skip);
                 let b = std::slice::from_raw_parts(b, b_skip);
                 let c = std::slice::from_raw_parts_mut(c, c_skip);
-                let transa = match gemm.transa {
-                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
-                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
-                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
-                    _ => todo!(),
-                };
-                let transb = match gemm.transb {
-                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
-                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
-                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
-                    _ => todo!(),
-                };
                 blas::sgemm(
-                    transa, transb, gemm.m, gemm.n, gemm.k, gemm.alpha, a, gemm.lda, b, gemm.ldb,
-                    gemm.beta, c, gemm.ldc,
+                    transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+                    /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+                    /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, /* beta= */ 0.,
+                    /* c= */ c, /* ldc= */ n as i32,
                 )
-                // gemm(
-                //     /* m: usize = */ m,
-                //     /* n: usize = */ n,
-                //     /* k: usize = */ k,
-                //     /* dst: *mut T = */ dst_p.as_mut_ptr(),
-                //     /* dst_cs: isize = */ dst_cs as isize,
-                //     /* dst_rs: isize = */ dst_rs as isize,
-                //     /* read_dst: bool = */ false,
-                //     /* lhs: *const T = */ lhs_p.as_ptr(),
-                //     /* lhs_cs: isize = */ lhs_cs as isize,
-                //     /* lhs_rs: isize = */ lhs_rs as isize,
-                //     /* rhs: *const T = */ rhs_p.as_ptr(),
-                //     /* rhs_cs: isize = */ rhs_cs as isize,
-                //     /* rhs_rs: isize = */ rhs_rs as isize,
-                //     /* alpha: T = */ T::zero(),
-                //     /* beta: T = */ T::one(),
-                //     /* conj_dst: bool = */ false,
-                //     /* conj_lhs: bool = */ false,
-                //     /* conj_rhs: bool = */ false,
-                //     parallelism,
-                // )
             }
         }
         Ok(dst)
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 927a5944..917655fc 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -543,7 +543,7 @@ pub struct CudaStorage {
     device: CudaDevice,
 }
 
-pub(crate) fn gemm_config<T>(
+fn gemm_config<T>(
    alpha: T,
    beta: T,
    (b, m, n, k): (usize, usize, usize, usize),
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 7a2d2984..9f12b9a2 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -129,6 +129,13 @@ pub enum Error {
 
     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
     BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
+
+    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
+    MatMulNonContiguous {
+        lhs_stride: Vec<usize>,
+        rhs_stride: Vec<usize>,
+        mnk: (usize, usize, usize),
+    },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index fbb5e03c..1fba7bbd 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -27,7 +27,10 @@ mod var_store;
 mod weights;
 
 const MAX_SEQ_LEN: usize = 4096;
+#[cfg(feature = "mkl")]
 const DTYPE: DType = DType::F32;
+#[cfg(not(feature = "mkl"))]
+const DTYPE: DType = DType::F16;
 const DEFAULT_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
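A note on the stride checks in the cpu_backend hunk: because the sgemm call hands the operands to a column-major BLAS in swapped order, `transa`/`lda` describe the rhs (shape k x n) and `transb`/`ldb` the lhs (shape m x k). Both branches follow the same pattern, which can be written as one standalone helper. This is a sketch for illustration only; `ld_and_trans` is a hypothetical name, not part of the patch:

```rust
/// Hypothetical standalone version of the stride classification above:
/// given the two innermost strides of a matrix with logical shape
/// (rows, cols), decide whether BLAS can read the buffer directly.
fn ld_and_trans(stride: &[usize], rows: usize, cols: usize) -> Option<(i32, u8)> {
    let s1 = stride[stride.len() - 1]; // stride of the innermost (column) index
    let s2 = stride[stride.len() - 2]; // stride of the row index
    if s1 == 1 && s2 == cols {
        // Plain row-major layout: no transpose, leading dimension = cols.
        Some((cols as i32, b'N'))
    } else if s1 == rows && s2 == 1 {
        // Transposed view of a contiguous matrix: ask BLAS to transpose,
        // leading dimension = rows.
        Some((rows as i32, b'T'))
    } else {
        // Anything else (e.g. a strided slice) maps to MatMulNonContiguous.
        None
    }
}

fn main() {
    // rhs of shape (k, n) = (4, 3), stored contiguously row-major.
    assert_eq!(ld_and_trans(&[3, 1], 4, 3), Some((3, b'N')));
    // The same buffer seen through a transposed view: strides swapped.
    assert_eq!(ld_and_trans(&[1, 4], 4, 3), Some((4, b'T')));
    // A non-contiguous slice needs a copy before sgemm can consume it.
    assert_eq!(ld_and_trans(&[6, 2], 4, 3), None);
}
```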
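The sgemm call itself passes the operands in reverse order (a = rhs, b = lhs) with m and n exchanged. That is the standard trick for driving a column-major BLAS from row-major buffers: a row-major matrix reinterpreted as column-major is its transpose, so C = lhs * rhs can be obtained from C^T = rhs^T * lhs^T, and C^T in column-major order is exactly the row-major C we want. A minimal self-contained sketch, with naive loops standing in for MKL's sgemm and covering only the contiguous 'N'/'N' case (the 'T' branches work the same way):

```rust
/// Naive stand-in for a column-major sgemm with transa = transb = 'N':
/// c (m x n) = a (m x k) * b (k x n), every buffer in column-major order,
/// element (row, col) living at col * ld + row.
fn col_major_gemm(
    m: usize, n: usize, k: usize,
    a: &[f32], lda: usize,
    b: &[f32], ldb: usize,
    c: &mut [f32], ldc: usize,
) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = 0f32;
            for p in 0..k {
                acc += a[p * lda + i] * b[j * ldb + p];
            }
            c[j * ldc + i] = acc;
        }
    }
}

fn main() {
    let (m, n, k) = (2, 3, 4);
    // Row-major lhs (m x k) and rhs (k x n).
    let lhs: Vec<f32> = (0..m * k).map(|v| v as f32).collect();
    let rhs: Vec<f32> = (0..k * n).map(|v| v as f32 * 0.5).collect();

    // Reference result: textbook row-major matmul.
    let mut expected = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            expected[i * n + j] = (0..k).map(|p| lhs[i * k + p] * rhs[p * n + j]).sum();
        }
    }

    // The swap used in the patch: a = rhs, b = lhs, dimensions n/m exchanged,
    // ldc = n. The column-major kernel then writes `dst` in row-major order.
    let mut dst = vec![0f32; m * n];
    col_major_gemm(n, m, k, &rhs, n, &lhs, k, &mut dst, n);
    assert_eq!(dst, expected);
}
```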