mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Use cublas bf16. (#101)
This commit is contained in:
@ -14,7 +14,8 @@ readme = "README.md"
|
|||||||
blas = { version = "0.22.0", optional = true }
|
blas = { version = "0.22.0", optional = true }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle-kernels = { path = "../candle-kernels", optional = true }
|
candle-kernels = { path = "../candle-kernels", optional = true }
|
||||||
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
|
# cudarc = { version = "0.9.12", optional = true, features = ["f16"] }
|
||||||
|
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] }
|
||||||
# TODO: Switch back to the official gemm implementation once something similar to
|
# TODO: Switch back to the official gemm implementation once something similar to
|
||||||
# https://github.com/sarah-ek/gemm/pull/8 is available.
|
# https://github.com/sarah-ek/gemm/pull/8 is available.
|
||||||
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
|
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
|
||||||
|
@ -827,8 +827,17 @@ impl CudaStorage {
|
|||||||
let elem_count = b * m * n;
|
let elem_count = b * m * n;
|
||||||
let dev = &self.device;
|
let dev = &self.device;
|
||||||
let slice = match (&self.slice, &rhs.slice) {
|
let slice = match (&self.slice, &rhs.slice) {
|
||||||
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
|
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
|
||||||
todo!("bf16")
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
|
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||||
|
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
||||||
|
unsafe {
|
||||||
|
self.device
|
||||||
|
.blas
|
||||||
|
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||||
|
}?;
|
||||||
|
CudaStorageSlice::BF16(out)
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
|
Reference in New Issue
Block a user