mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Remove the dependency to blas and use mkl directly. (#125)
This commit is contained in:
@ -11,7 +11,6 @@ license = "MIT/Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
blas = { version = "0.22.0", optional = true }
|
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle-kernels = { path = "../candle-kernels", optional = true }
|
candle-kernels = { path = "../candle-kernels", optional = true }
|
||||||
# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
|
# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
|
||||||
@ -22,6 +21,7 @@ cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas
|
|||||||
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
|
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
|
intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
|
||||||
|
libc = { version = "0.2.147", optional = true }
|
||||||
memmap2 = "0.7.1"
|
memmap2 = "0.7.1"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
num_cpus = "1.15.0"
|
num_cpus = "1.15.0"
|
||||||
@ -35,4 +35,4 @@ anyhow = { version = "1", features = ["backtrace"] }
|
|||||||
[features]
|
[features]
|
||||||
default = ["cuda"]
|
default = ["cuda"]
|
||||||
cuda = ["dep:cudarc", "dep:candle-kernels"]
|
cuda = ["dep:cudarc", "dep:candle-kernels"]
|
||||||
mkl = ["dep:blas", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
|
@ -416,6 +416,36 @@ impl Map2 for MatMul {
|
|||||||
|
|
||||||
let mut dst = vec![T::zero(); b * m * n];
|
let mut dst = vec![T::zero(); b * m * n];
|
||||||
match T::DTYPE {
|
match T::DTYPE {
|
||||||
|
DType::F16 => {
|
||||||
|
for step in 0..b {
|
||||||
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
|
unsafe {
|
||||||
|
let a = rhs_p.as_ptr() as *const f16;
|
||||||
|
let b = lhs_p.as_ptr() as *const f16;
|
||||||
|
let c = dst_p.as_mut_ptr() as *mut f16;
|
||||||
|
let a = std::slice::from_raw_parts(a, a_skip);
|
||||||
|
let b = std::slice::from_raw_parts(b, b_skip);
|
||||||
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||||
|
crate::mkl::hgemm(
|
||||||
|
transa,
|
||||||
|
transb,
|
||||||
|
/* m= */ n as i32,
|
||||||
|
/* n= */ m as i32,
|
||||||
|
/* k= */ k as i32,
|
||||||
|
/* alpha= */ f16::ONE,
|
||||||
|
/* a= */ a,
|
||||||
|
/* lda= */ lda,
|
||||||
|
/* b= */ b,
|
||||||
|
/* ldb= */ ldb,
|
||||||
|
/* beta= */ f16::ZERO,
|
||||||
|
/* c= */ c,
|
||||||
|
/* ldc= */ n as i32,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
for step in 0..b {
|
for step in 0..b {
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
@ -428,7 +458,7 @@ impl Map2 for MatMul {
|
|||||||
let a = std::slice::from_raw_parts(a, a_skip);
|
let a = std::slice::from_raw_parts(a, a_skip);
|
||||||
let b = std::slice::from_raw_parts(b, b_skip);
|
let b = std::slice::from_raw_parts(b, b_skip);
|
||||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||||
blas::sgemm(
|
crate::mkl::sgemm(
|
||||||
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||||
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||||
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
||||||
@ -449,7 +479,7 @@ impl Map2 for MatMul {
|
|||||||
let a = std::slice::from_raw_parts(a, a_skip);
|
let a = std::slice::from_raw_parts(a, a_skip);
|
||||||
let b = std::slice::from_raw_parts(b, b_skip);
|
let b = std::slice::from_raw_parts(b, b_skip);
|
||||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||||
blas::dgemm(
|
crate::mkl::dgemm(
|
||||||
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||||
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||||
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
||||||
|
@ -44,6 +44,8 @@ mod dtype;
|
|||||||
mod dummy_cuda_backend;
|
mod dummy_cuda_backend;
|
||||||
mod error;
|
mod error;
|
||||||
mod layout;
|
mod layout;
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
mod mkl;
|
||||||
mod npy;
|
mod npy;
|
||||||
mod op;
|
mod op;
|
||||||
pub mod safetensors;
|
pub mod safetensors;
|
||||||
|
154
candle-core/src/mkl.rs
Normal file
154
candle-core/src/mkl.rs
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
use libc::{c_char, c_double, c_float, c_int};
|
||||||
|
|
||||||
|
mod ffi {
|
||||||
|
use super::*;
|
||||||
|
extern "C" {
|
||||||
|
pub fn sgemm_(
|
||||||
|
transa: *const c_char,
|
||||||
|
transb: *const c_char,
|
||||||
|
m: *const c_int,
|
||||||
|
n: *const c_int,
|
||||||
|
k: *const c_int,
|
||||||
|
alpha: *const c_float,
|
||||||
|
a: *const c_float,
|
||||||
|
lda: *const c_int,
|
||||||
|
b: *const c_float,
|
||||||
|
ldb: *const c_int,
|
||||||
|
beta: *const c_float,
|
||||||
|
c: *mut c_float,
|
||||||
|
ldc: *const c_int,
|
||||||
|
);
|
||||||
|
pub fn dgemm_(
|
||||||
|
transa: *const c_char,
|
||||||
|
transb: *const c_char,
|
||||||
|
m: *const c_int,
|
||||||
|
n: *const c_int,
|
||||||
|
k: *const c_int,
|
||||||
|
alpha: *const c_double,
|
||||||
|
a: *const c_double,
|
||||||
|
lda: *const c_int,
|
||||||
|
b: *const c_double,
|
||||||
|
ldb: *const c_int,
|
||||||
|
beta: *const c_double,
|
||||||
|
c: *mut c_double,
|
||||||
|
ldc: *const c_int,
|
||||||
|
);
|
||||||
|
pub fn hgemm_(
|
||||||
|
transa: *const c_char,
|
||||||
|
transb: *const c_char,
|
||||||
|
m: *const c_int,
|
||||||
|
n: *const c_int,
|
||||||
|
k: *const c_int,
|
||||||
|
alpha: *const half::f16,
|
||||||
|
a: *const half::f16,
|
||||||
|
lda: *const c_int,
|
||||||
|
b: *const half::f16,
|
||||||
|
ldb: *const c_int,
|
||||||
|
beta: *const half::f16,
|
||||||
|
c: *mut half::f16,
|
||||||
|
ldc: *const c_int,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
#[inline]
|
||||||
|
pub unsafe fn sgemm(
|
||||||
|
transa: u8,
|
||||||
|
transb: u8,
|
||||||
|
m: i32,
|
||||||
|
n: i32,
|
||||||
|
k: i32,
|
||||||
|
alpha: f32,
|
||||||
|
a: &[f32],
|
||||||
|
lda: i32,
|
||||||
|
b: &[f32],
|
||||||
|
ldb: i32,
|
||||||
|
beta: f32,
|
||||||
|
c: &mut [f32],
|
||||||
|
ldc: i32,
|
||||||
|
) {
|
||||||
|
ffi::sgemm_(
|
||||||
|
&(transa as c_char),
|
||||||
|
&(transb as c_char),
|
||||||
|
&m,
|
||||||
|
&n,
|
||||||
|
&k,
|
||||||
|
&alpha,
|
||||||
|
a.as_ptr(),
|
||||||
|
&lda,
|
||||||
|
b.as_ptr(),
|
||||||
|
&ldb,
|
||||||
|
&beta,
|
||||||
|
c.as_mut_ptr(),
|
||||||
|
&ldc,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
#[inline]
|
||||||
|
pub unsafe fn dgemm(
|
||||||
|
transa: u8,
|
||||||
|
transb: u8,
|
||||||
|
m: i32,
|
||||||
|
n: i32,
|
||||||
|
k: i32,
|
||||||
|
alpha: f64,
|
||||||
|
a: &[f64],
|
||||||
|
lda: i32,
|
||||||
|
b: &[f64],
|
||||||
|
ldb: i32,
|
||||||
|
beta: f64,
|
||||||
|
c: &mut [f64],
|
||||||
|
ldc: i32,
|
||||||
|
) {
|
||||||
|
ffi::dgemm_(
|
||||||
|
&(transa as c_char),
|
||||||
|
&(transb as c_char),
|
||||||
|
&m,
|
||||||
|
&n,
|
||||||
|
&k,
|
||||||
|
&alpha,
|
||||||
|
a.as_ptr(),
|
||||||
|
&lda,
|
||||||
|
b.as_ptr(),
|
||||||
|
&ldb,
|
||||||
|
&beta,
|
||||||
|
c.as_mut_ptr(),
|
||||||
|
&ldc,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
#[inline]
|
||||||
|
pub unsafe fn hgemm(
|
||||||
|
transa: u8,
|
||||||
|
transb: u8,
|
||||||
|
m: i32,
|
||||||
|
n: i32,
|
||||||
|
k: i32,
|
||||||
|
alpha: half::f16,
|
||||||
|
a: &[half::f16],
|
||||||
|
lda: i32,
|
||||||
|
b: &[half::f16],
|
||||||
|
ldb: i32,
|
||||||
|
beta: half::f16,
|
||||||
|
c: &mut [half::f16],
|
||||||
|
ldc: i32,
|
||||||
|
) {
|
||||||
|
ffi::hgemm_(
|
||||||
|
&(transa as c_char),
|
||||||
|
&(transb as c_char),
|
||||||
|
&m,
|
||||||
|
&n,
|
||||||
|
&k,
|
||||||
|
&alpha,
|
||||||
|
a.as_ptr(),
|
||||||
|
&lda,
|
||||||
|
b.as_ptr(),
|
||||||
|
&ldb,
|
||||||
|
&beta,
|
||||||
|
c.as_mut_ptr(),
|
||||||
|
&ldc,
|
||||||
|
)
|
||||||
|
}
|
Reference in New Issue
Block a user