mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Use cublas bf16. (#101)
This commit is contained in:
@ -14,7 +14,8 @@ readme = "README.md"
|
||||
blas = { version = "0.22.0", optional = true }
|
||||
byteorder = "1.4.3"
|
||||
candle-kernels = { path = "../candle-kernels", optional = true }
|
||||
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
|
||||
# cudarc = { version = "0.9.12", optional = true, features = ["f16"] }
|
||||
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] }
|
||||
# TODO: Switch back to the official gemm implementation once something similar to
|
||||
# https://github.com/sarah-ek/gemm/pull/8 is available.
|
||||
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
|
||||
|
@ -827,8 +827,17 @@ impl CudaStorage {
|
||||
let elem_count = b * m * n;
|
||||
let dev = &self.device;
|
||||
let slice = match (&self.slice, &rhs.slice) {
|
||||
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
|
||||
todo!("bf16")
|
||||
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||
}?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
|
Reference in New Issue
Block a user