Add mkl support for matrix multiply. (#86)
* Fix some rebase issues.
* Use mkl instead.
* Use mkl in bert.
* Add the optional mkl feature.
* Conditional compilation based on the mkl feature.
* Add more mkl support.
@@ -11,6 +11,7 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 
 [dependencies]
+blas = { version = "0.22.0", optional = true }
 byteorder = "1.4.3"
 candle-kernels = { path = "../candle-kernels", optional = true }
 cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
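A note on the new dependency: the `blas` crate only provides the raw Fortran BLAS bindings (the `sgemm` call used further down); an actual BLAS implementation still has to be linked in, which is what `intel-mkl-src` below is for. Both stay optional and only come into play through the new `mkl` feature declared at the end of this file.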
@@ -18,6 +19,7 @@ cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
 # https://github.com/sarah-ek/gemm/pull/8 is available.
 gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
 half = { version = "2.3.1", features = ["num-traits"] }
+intel-mkl-src = { version = "0.8.1", optional = true, features = ["mkl-dynamic-lp64-iomp"] }
 memmap2 = "0.7.1"
 num-traits = "0.2.15"
 num_cpus = "1.15.0"
@@ -31,3 +33,4 @@ anyhow = { version = "1", features = ["backtrace"] }
 [features]
 default = ["cuda"]
 cuda = ["dep:cudarc", "dep:candle-kernels"]
+mkl = ["dep:blas", "dep:intel-mkl-src"]
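The `dep:` syntax keeps the optional crates from leaking out as implicit features of their own; users opt in with `cargo build --features mkl`. A minimal, self-contained sketch of the compile-time dispatch pattern this feature enables (function names here are illustrative, not from this commit):

    // Sketch: compile-time dispatch on the cargo feature (illustrative names).
    #[cfg(feature = "mkl")]
    fn matmul_backend() -> &'static str {
        "mkl (BLAS sgemm)"
    }

    #[cfg(not(feature = "mkl"))]
    fn matmul_backend() -> &'static str {
        "gemm crate (pure Rust)"
    }

    fn main() {
        println!("using backend: {}", matmul_backend());
    }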
@@ -1,3 +1,6 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use anyhow::Result;
 use candle::{Device, Tensor};
@@ -1,3 +1,6 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use anyhow::Result;
 use candle::{Device, Tensor};
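The `extern crate intel_mkl_src;` lines here (and in the binaries further down) look redundant in edition 2018+, but they are load-bearing: `intel-mkl-src` exposes no items that the code ever calls, so without an explicit `extern crate` reference the MKL libraries it builds would not be linked into the final executable. The `#[cfg(feature = "mkl")]` gate keeps non-MKL builds unaffected.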
@@ -1,6 +1,5 @@
 use crate::op::{BinaryOp, UnaryOp};
 use crate::{DType, Error, Layout, Result, Shape, WithDType};
-use gemm::{gemm, Parallelism};
 use half::{bf16, f16};
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -264,6 +263,8 @@ struct MatMul((usize, usize, usize, usize));
 
 impl Map2 for MatMul {
     const OP: &'static str = "mat_mul";
+
+    #[cfg(not(feature = "mkl"))]
     fn f<T: 'static + num_traits::Num + Copy>(
         &self,
         lhs: &[T],
@@ -271,6 +272,7 @@ impl Map2 for MatMul {
         rhs: &[T],
         rhs_l: &Layout,
     ) -> Result<Vec<T>> {
+        use gemm::{gemm, Parallelism};
         let (b, m, n, k) = self.0;
         let lhs = &lhs[lhs_l.start_offset()..];
         let rhs = &rhs[rhs_l.start_offset()..];
@@ -346,6 +348,98 @@ impl Map2 for MatMul {
         }
         Ok(dst)
     }
+
+    #[cfg(feature = "mkl")]
+    fn f<T: 'static + num_traits::Num + Copy>(
+        &self,
+        lhs: &[T],
+        lhs_l: &Layout,
+        rhs: &[T],
+        rhs_l: &Layout,
+    ) -> Result<Vec<T>> {
+        let (b, m, n, k) = self.0;
+        let lhs = &lhs[lhs_l.start_offset()..];
+        let rhs = &rhs[rhs_l.start_offset()..];
+        // Reuse the cuBLAS parameter derivation: a column-major CPU BLAS
+        // needs the same m/n swap, transpose flags and leading dimensions.
+        let cfg = crate::cuda_backend::gemm_config(1f32, 0f32, (b, m, n, k), lhs_l, rhs_l)?;
+
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();
+
+        // Per-batch element strides for lhs/rhs, mirroring the non-mkl path.
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(Error::UnexpectedStriding {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+            })?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(Error::UnexpectedStriding {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+            })?,
+        };
+        let c_skip: usize = m * n;
+
+        let mut dst = vec![T::zero(); b * m * n];
+        for step in 0..b {
+            let lhs_p = &lhs[step * a_skip..];
+            let rhs_p = &rhs[step * b_skip..];
+            let dst_p = &mut dst[step * c_skip..];
+            unsafe {
+                let gemm = cfg.gemm;
+                // BLAS is column-major, so compute C^T = B^T * A^T: the rhs
+                // buffer becomes operand `a` and the lhs buffer operand `b`.
+                // Note that the casts assume f32 data, so this path is only
+                // valid for DType::F32.
+                let a = rhs_p.as_ptr() as *const f32;
+                let b = lhs_p.as_ptr() as *const f32;
+                let c = dst_p.as_mut_ptr() as *mut f32;
+                // `a` holds the rhs batch (b_skip elements) and `b` the lhs
+                // batch (a_skip elements).
+                let a = std::slice::from_raw_parts(a, b_skip);
+                let b = std::slice::from_raw_parts(b, a_skip);
+                let c = std::slice::from_raw_parts_mut(c, c_skip);
+                let transa = match gemm.transa {
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
+                    _ => todo!(),
+                };
+                let transb = match gemm.transb {
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
+                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
+                    _ => todo!(),
+                };
+                blas::sgemm(
+                    transa, transb, gemm.m, gemm.n, gemm.k, gemm.alpha, a, gemm.lda, b, gemm.ldb,
+                    gemm.beta, c, gemm.ldc,
+                )
+                // Leftover gemm-crate call, kept for reference:
+                // gemm(
+                //     /* m: usize = */ m,
+                //     /* n: usize = */ n,
+                //     /* k: usize = */ k,
+                //     /* dst: *mut T = */ dst_p.as_mut_ptr(),
+                //     /* dst_cs: isize = */ dst_cs as isize,
+                //     /* dst_rs: isize = */ dst_rs as isize,
+                //     /* read_dst: bool = */ false,
+                //     /* lhs: *const T = */ lhs_p.as_ptr(),
+                //     /* lhs_cs: isize = */ lhs_cs as isize,
+                //     /* lhs_rs: isize = */ lhs_rs as isize,
+                //     /* rhs: *const T = */ rhs_p.as_ptr(),
+                //     /* rhs_cs: isize = */ rhs_cs as isize,
+                //     /* rhs_rs: isize = */ rhs_rs as isize,
+                //     /* alpha: T = */ T::zero(),
+                //     /* beta: T = */ T::one(),
+                //     /* conj_dst: bool = */ false,
+                //     /* conj_lhs: bool = */ false,
+                //     /* conj_rhs: bool = */ false,
+                //     parallelism,
+                // )
+            }
+        }
+        Ok(dst)
+    }
 }
 
 fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
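The operand swap above (`a` built from `rhs_p`, `b` from `lhs_p`) is the standard trick for driving a column-major BLAS with row-major buffers: a row-major m-by-k matrix reinterpreted as column-major is exactly its transpose, so asking for the n-by-m product B^T * A^T makes BLAS write C^T in column-major order, which is byte-for-byte the row-major C = A * B. A small self-contained check of that identity, with a naive column-major GEMM standing in for `sgemm` (everything here is illustrative, not candle API):

    // Plain column-major GEMM: c[i + j*m] = sum_p a[i + p*m] * b[p + j*k],
    // for an m-by-k `a` and a k-by-n `b`, both column-major.
    fn gemm_col_major(m: usize, n: usize, k: usize, a: &[f32], b: &[f32]) -> Vec<f32> {
        let mut c = vec![0f32; m * n];
        for j in 0..n {
            for i in 0..m {
                for p in 0..k {
                    c[i + j * m] += a[i + p * m] * b[p + j * k];
                }
            }
        }
        c
    }

    fn main() {
        // Row-major A (2x3) and B (3x2); expected row-major C = A*B (2x2).
        let (m, n, k) = (2, 2, 3);
        let a = [1f32, 2., 3., 4., 5., 6.]; // A, row-major 2x3
        let b = [7f32, 8., 9., 10., 11., 12.]; // B, row-major 3x2
        // The swap: compute the n-by-m column-major product B^T * A^T.
        let c = gemm_col_major(n, m, k, &b, &a);
        // Column-major C^T is identical to row-major C = [[58, 64], [139, 154]].
        assert_eq!(c, vec![58., 64., 139., 154.]);
        println!("row-major C = A*B: {c:?}");
    }

This is presumably also why `cfg.gemm`'s m/n/k, transpose flags and leading dimensions can be passed straight through to `sgemm`: cuBLAS is column-major too, and `gemm_config` already encodes the same swap.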
@@ -543,7 +543,7 @@ pub struct CudaStorage {
     device: CudaDevice,
 }
 
-fn gemm_config<T>(
+pub(crate) fn gemm_config<T>(
     alpha: T,
     beta: T,
     (b, m, n, k): (usize, usize, usize, usize),
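Widening `gemm_config` to `pub(crate)` lets the CPU path reuse the already-debugged cuBLAS parameter derivation instead of duplicating it. The trade-off is that the `mkl` matmul now reaches into `cuda_backend`, so the two code paths are coupled at the module level.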
@@ -15,6 +15,7 @@ candle = { path = "../candle-core", default-features=false }
 serde = { version = "1.0.166", features = ["derive"] }
 serde_json = "1.0.99"
 num-traits = "0.2.15"
+intel-mkl-src = { version = "0.8.1", optional = true, features = ["mkl-dynamic-lp64-iomp"] }
 
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
@@ -28,3 +29,4 @@ wav = "1.0.0"
 [features]
 default = ["cuda"]
 cuda = ["candle/cuda"]
+mkl = ["dep:intel-mkl-src", "candle/mkl"]
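The feature has to be wired up at both levels: this crate pulls in `intel-mkl-src` itself (so the linker sees MKL for these binaries) and forwards `candle/mkl` (so the library compiles its MKL matmul path). Enabling only one half would either fail to link or silently keep the pure-Rust gemm. Something like `cargo run --release --features mkl` from this crate should activate both.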
@@ -1,4 +1,8 @@
 #![allow(dead_code)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use candle_hub::{api::Api, Cache, Repo, RepoType};
@@ -9,6 +9,9 @@
 // In order to convert the llama weights to a .npz file, run:
 // python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
 
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 // TODO: This does not use a batch dimension. If adding it back, be cautious about the
 // transposition operations.
 use anyhow::{Error as E, Result};
@@ -24,7 +27,7 @@ mod var_store;
 mod weights;
 
 const MAX_SEQ_LEN: usize = 4096;
-const DTYPE: DType = DType::F16;
+const DTYPE: DType = DType::F32;
 const DEFAULT_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
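The dtype change is tied to the MKL path: the new `f` reinterprets every buffer as `*const f32` and always calls `sgemm`, so it is only correct for `DType::F32`. Switching the llama example from f16 to f32 keeps it on the supported path, at the cost of larger weights in memory. A generic function could guard against that reinterpretation explicitly; a hypothetical sketch (not something this commit does):

    use std::any::TypeId;

    // Hypothetical guard: check at runtime that the generic parameter really
    // is f32 before reinterpreting raw pointers as *const f32.
    fn is_f32<T: 'static>() -> bool {
        TypeId::of::<T>() == TypeId::of::<f32>()
    }

    fn main() {
        assert!(is_f32::<f32>());
        assert!(!is_f32::<f64>());
        println!("f32 check ok");
    }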
@@ -6,6 +6,9 @@
 // - Batch size greater than 1.
 // - More token filters (SuppressBlanks, ApplyTimestampRules).
 
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
 use candle_hub::{api::Api, Repo, RepoType};