From c297a5096029edcc69117d3f1a7b97e7a5fc6767 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 6 Jul 2023 11:05:05 +0100
Subject: [PATCH] Add mkl support for matrix multiply. (#86)

* Fix some rebase issues.

* Use mkl instead.

* Use mkl in bert.

* Add the optional mkl feature.

* Conditional compilation based on the mkl feature.

* Add more mkl support.
---
 candle-core/Cargo.toml                   |  3 +
 candle-core/examples/basics.rs           |  3 +
 candle-core/examples/cuda_basics.rs      |  3 +
 candle-core/src/cpu_backend.rs           | 96 +++++++++++++++++++++++-
 candle-core/src/cuda_backend.rs          |  2 +-
 candle-examples/Cargo.toml               |  2 +
 candle-examples/examples/bert/main.rs    |  4 +
 candle-examples/examples/llama/main.rs   |  5 +-
 candle-examples/examples/whisper/main.rs |  3 +
 9 files changed, 118 insertions(+), 3 deletions(-)

diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 85f77af5..529f9812 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -11,6 +11,7 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 
 [dependencies]
+blas = { version = "0.22.0", optional = true }
 byteorder = "1.4.3"
 candle-kernels = { path = "../candle-kernels", optional = true }
 cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
@@ -18,6 +19,7 @@ cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
 # https://github.com/sarah-ek/gemm/pull/8 is available.
 gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
 half = { version = "2.3.1", features = ["num-traits"] }
+intel-mkl-src = { version = "0.8.1", optional = true, features = ["mkl-dynamic-lp64-iomp"] }
 memmap2 = "0.7.1"
 num-traits = "0.2.15"
 num_cpus = "1.15.0"
@@ -31,3 +33,4 @@ anyhow = { version = "1", features = ["backtrace"] }
 [features]
 default = ["cuda"]
 cuda = ["dep:cudarc", "dep:candle-kernels"]
+mkl = ["dep:blas", "dep:intel-mkl-src"]
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index 733a7fe0..a5e2b24e 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -1,3 +1,6 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use anyhow::Result;
 use candle::{Device, Tensor};
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index c8852bf6..6050d793 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -1,3 +1,6 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use anyhow::Result;
 use candle::{Device, Tensor};
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index b2345756..7ccadb44 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,6 +1,5 @@
 use crate::op::{BinaryOp, UnaryOp};
 use crate::{DType, Error, Layout, Result, Shape, WithDType};
-use gemm::{gemm, Parallelism};
 use half::{bf16, f16};
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -264,6 +263,8 @@ struct MatMul((usize, usize, usize, usize));
 
 impl Map2 for MatMul {
     const OP: &'static str = "mat_mul";
+
+    #[cfg(not(feature = "mkl"))]
     fn f<T: 'static + WithDType + num_traits::Num + Copy>(
         &self,
         lhs: &[T],
@@ -271,6 +272,7 @@ impl Map2 for MatMul {
         rhs: &[T],
         rhs_l: &Layout,
     ) -> Result<Vec<T>> {
+        use gemm::{gemm, Parallelism};
         let (b, m, n, k) = self.0;
         let lhs = &lhs[lhs_l.start_offset()..];
         let rhs = &rhs[rhs_l.start_offset()..];
@@ -346,6 +348,98 @@ impl Map2 for MatMul {
             }
             Ok(dst)
         }
+
+    #[cfg(feature = "mkl")]
+    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+        &self,
+        lhs: &[T],
+        lhs_l: &Layout,
+        rhs: &[T],
+        rhs_l: &Layout,
+    ) -> Result<Vec<T>> {
+        let (b, m, n, k) = self.0;
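+        // Row-major buffers vs column-major BLAS: the product is computed as
+        // C^T = B^T * A^T, so the rhs buffer is passed as operand `a` and the
+        // lhs buffer as operand `b` below.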
+        let lhs = &lhs[lhs_l.start_offset()..];
+        let rhs = &rhs[rhs_l.start_offset()..];
+        // Reuse the cuBLAS configuration to derive the transpose flags and
+        // leading dimensions (this helper lives in the cuda backend).
+        let cfg = crate::cuda_backend::gemm_config(1f32, 0f32, (b, m, n, k), lhs_l, rhs_l)?;
+
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();
+
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(Error::UnexpectedStriding {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+            })?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(Error::UnexpectedStriding {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+            })?,
+        };
+        let c_skip: usize = m * n;
+
+        let mut dst = vec![T::zero(); b * m * n];
+        for step in 0..b {
+            let lhs_p = &lhs[step * a_skip..];
+            let rhs_p = &rhs[step * b_skip..];
+            let dst_p = &mut dst[step * c_skip..];
+            unsafe {
+                let gemm = cfg.gemm;
+                // The pointer casts assume T = f32, the only dtype sgemm handles.
+                let a = rhs_p.as_ptr() as *const f32;
+                let b = lhs_p.as_ptr() as *const f32;
+                let c = dst_p.as_mut_ptr() as *mut f32;
+                // `a` points into the rhs buffer and `b` into the lhs buffer,
+                // so their slice lengths are b_skip and a_skip respectively.
+                let a = std::slice::from_raw_parts(a, b_skip);
+                let b = std::slice::from_raw_parts(b, a_skip);
+                let c = std::slice::from_raw_parts_mut(c, c_skip);
+                let transa = match gemm.transa {
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
+                    _ => todo!(),
+                };
+                let transb = match gemm.transb {
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
+                    _ => todo!(),
+                };
+                blas::sgemm(
+                    transa, transb, gemm.m, gemm.n, gemm.k, gemm.alpha, a, gemm.lda, b, gemm.ldb,
+                    gemm.beta, c, gemm.ldc,
+                );
+            }
+        }
+        Ok(dst)
+    }
 }
 
 fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 917655fc..927a5944 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -543,7 +543,7 @@ pub struct CudaStorage {
     device: CudaDevice,
 }
 
-fn gemm_config<T>(
+pub(crate) fn gemm_config<T>(
     alpha: T,
     beta: T,
     (b, m, n, k): (usize, usize, usize, usize),
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index a39ee3a3..98cad54f 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -15,6 +15,7 @@ candle = { path = "../candle-core", default-features=false }
 serde = { version = "1.0.166", features = ["derive"] }
 serde_json = "1.0.99"
 num-traits = "0.2.15"
+intel-mkl-src = { version = "0.8.1", optional = true, features = ["mkl-dynamic-lp64-iomp"] }
 
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
@@ -28,3 +29,4 @@ wav = "1.0.0"
 [features]
 default = ["cuda"]
 cuda = ["candle/cuda"]
+mkl = ["dep:intel-mkl-src", "candle/mkl"]
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 11d01a6a..bf99b1bf 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -1,4 +1,8 @@
 #![allow(dead_code)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use candle_hub::{api::Api, Cache, Repo, RepoType};
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 9f87b59a..fbb5e03c 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -9,6 +9,9 @@
 // In order to convert the llama weights to a .npz file, run:
 // python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
 
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 // TODO: This does not use a batch dimension. If adding it back, be cautious about the
 // transposition operations.
 use anyhow::{Error as E, Result};
@@ -24,7 +27,7 @@
 mod var_store;
 mod weights;
 
 const MAX_SEQ_LEN: usize = 4096;
-const DTYPE: DType = DType::F16;
+const DTYPE: DType = DType::F32;
 
 const DEFAULT_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 7679f1a2..0c9d0893 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -6,6 +6,9 @@
 // - Batch size greater than 1.
 // - More token filters (SuppressBlanks, ApplyTimestampRules).
 
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
 use candle_hub::{api::Api, Repo, RepoType};
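
A note on the operand order in the MKL path above: BLAS sgemm follows the Fortran column-major convention while candle tensors are row-major. A row-major product C = A*B can be obtained from a column-major gemm through the identity C^T = B^T * A^T, which is why the rhs buffer is passed as BLAS operand `a` and the lhs buffer as operand `b`. The sketch below is not part of the patch; it is a minimal standalone illustration of that trick, assuming the same `blas` and `intel-mkl-src` dependencies (and an MKL install for linking) that the patch adds:

    #[cfg(feature = "mkl")]
    extern crate intel_mkl_src;

    fn main() {
        // A is 2x4 and B is 4x3, both stored row-major.
        let (m, n, k) = (2i32, 3i32, 4i32);
        let a: Vec<f32> = (0..m * k).map(|v| v as f32).collect();
        let b: Vec<f32> = (0..k * n).map(|v| v as f32).collect();
        let mut c = vec![0f32; (m * n) as usize];
        unsafe {
            // Column-major sgemm computing C^T = B^T * A^T: the rhs goes in
            // first, and the leading dimensions are the row-major row lengths.
            blas::sgemm(b'N', b'N', n, m, k, 1.0, &b, n, &a, k, 0.0, &mut c, n);
        }
        // `c` now holds the row-major 2x3 product of A and B.
        assert_eq!(c, [42., 48., 54., 114., 136., 158.]);
    }

The `b'N'`/`b'T'` bytes match the Fortran character arguments expected by the `blas` crate wrappers, which is also why the patch maps the cuBLAS operation enums to single bytes before calling sgemm.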