mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
MKL adjustments. (#87)
This commit is contained in:
@ -360,7 +360,6 @@ impl Map2 for MatMul {
|
|||||||
let (b, m, n, k) = self.0;
|
let (b, m, n, k) = self.0;
|
||||||
let lhs = &lhs[lhs_l.start_offset()..];
|
let lhs = &lhs[lhs_l.start_offset()..];
|
||||||
let rhs = &rhs[rhs_l.start_offset()..];
|
let rhs = &rhs[rhs_l.start_offset()..];
|
||||||
let cfg = crate::cuda_backend::gemm_config(1f32, 0f32, (b, m, n, k), lhs_l, rhs_l)?;
|
|
||||||
|
|
||||||
let lhs_stride = lhs_l.stride();
|
let lhs_stride = lhs_l.stride();
|
||||||
let rhs_stride = rhs_l.stride();
|
let rhs_stride = rhs_l.stride();
|
||||||
@ -386,56 +385,53 @@ impl Map2 for MatMul {
|
|||||||
};
|
};
|
||||||
let c_skip: usize = m * n;
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||||
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||||
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
|
||||||
|
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
|
(n as i32, b'N')
|
||||||
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
|
(k as i32, b'T')
|
||||||
|
} else {
|
||||||
|
Err(Error::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
|
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
|
(k as i32, b'N')
|
||||||
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
|
(m as i32, b'T')
|
||||||
|
} else {
|
||||||
|
Err(Error::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
|
||||||
let mut dst = vec![T::zero(); b * m * n];
|
let mut dst = vec![T::zero(); b * m * n];
|
||||||
for step in 0..b {
|
for step in 0..b {
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
unsafe {
|
unsafe {
|
||||||
let gemm = cfg.gemm;
|
|
||||||
let a = rhs_p.as_ptr() as *const f32;
|
let a = rhs_p.as_ptr() as *const f32;
|
||||||
let b = lhs_p.as_ptr() as *const f32;
|
let b = lhs_p.as_ptr() as *const f32;
|
||||||
let c = dst_p.as_mut_ptr() as *mut f32;
|
let c = dst_p.as_mut_ptr() as *mut f32;
|
||||||
let a = std::slice::from_raw_parts(a, a_skip);
|
let a = std::slice::from_raw_parts(a, a_skip);
|
||||||
let b = std::slice::from_raw_parts(b, b_skip);
|
let b = std::slice::from_raw_parts(b, b_skip);
|
||||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||||
let transa = match gemm.transa {
|
|
||||||
cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
|
|
||||||
cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
|
|
||||||
cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
|
|
||||||
_ => todo!(),
|
|
||||||
};
|
|
||||||
let transb = match gemm.transb {
|
|
||||||
cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
|
|
||||||
cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
|
|
||||||
cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
|
|
||||||
_ => todo!(),
|
|
||||||
};
|
|
||||||
blas::sgemm(
|
blas::sgemm(
|
||||||
transa, transb, gemm.m, gemm.n, gemm.k, gemm.alpha, a, gemm.lda, b, gemm.ldb,
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||||
gemm.beta, c, gemm.ldc,
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||||
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb, /* beta= */ 0.,
|
||||||
|
/* c= */ c, /* ldc= */ n as i32,
|
||||||
)
|
)
|
||||||
// gemm(
|
|
||||||
// /* m: usize = */ m,
|
|
||||||
// /* n: usize = */ n,
|
|
||||||
// /* k: usize = */ k,
|
|
||||||
// /* dst: *mut T = */ dst_p.as_mut_ptr(),
|
|
||||||
// /* dst_cs: isize = */ dst_cs as isize,
|
|
||||||
// /* dst_rs: isize = */ dst_rs as isize,
|
|
||||||
// /* read_dst: bool = */ false,
|
|
||||||
// /* lhs: *const T = */ lhs_p.as_ptr(),
|
|
||||||
// /* lhs_cs: isize = */ lhs_cs as isize,
|
|
||||||
// /* lhs_rs: isize = */ lhs_rs as isize,
|
|
||||||
// /* rhs: *const T = */ rhs_p.as_ptr(),
|
|
||||||
// /* rhs_cs: isize = */ rhs_cs as isize,
|
|
||||||
// /* rhs_rs: isize = */ rhs_rs as isize,
|
|
||||||
// /* alpha: T = */ T::zero(),
|
|
||||||
// /* beta: T = */ T::one(),
|
|
||||||
// /* conj_dst: bool = */ false,
|
|
||||||
// /* conj_lhs: bool = */ false,
|
|
||||||
// /* conj_rhs: bool = */ false,
|
|
||||||
// parallelism,
|
|
||||||
// )
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
|
@ -543,7 +543,7 @@ pub struct CudaStorage {
|
|||||||
device: CudaDevice,
|
device: CudaDevice,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn gemm_config<T>(
|
fn gemm_config<T>(
|
||||||
alpha: T,
|
alpha: T,
|
||||||
beta: T,
|
beta: T,
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
@ -129,6 +129,13 @@ pub enum Error {
|
|||||||
|
|
||||||
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
||||||
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
|
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
|
||||||
|
|
||||||
|
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||||
|
MatMulNonContiguous {
|
||||||
|
lhs_stride: Vec<usize>,
|
||||||
|
rhs_stride: Vec<usize>,
|
||||||
|
mnk: (usize, usize, usize),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
@ -27,7 +27,10 @@ mod var_store;
|
|||||||
mod weights;
|
mod weights;
|
||||||
|
|
||||||
const MAX_SEQ_LEN: usize = 4096;
|
const MAX_SEQ_LEN: usize = 4096;
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
|
#[cfg(not(feature = "mkl"))]
|
||||||
|
const DTYPE: DType = DType::F16;
|
||||||
const DEFAULT_PROMPT: &str = r"
|
const DEFAULT_PROMPT: &str = r"
|
||||||
EDWARD:
|
EDWARD:
|
||||||
I wonder how our princely father 'scaped,
|
I wonder how our princely father 'scaped,
|
||||||
|
Reference in New Issue
Block a user