Use cublas bf16. (#101)

This commit is contained in:
Laurent Mazare
2023-07-07 08:00:12 +01:00
committed by GitHub
parent c71a38deb7
commit 02b5c38049
2 changed files with 13 additions and 3 deletions

View File

@ -14,7 +14,8 @@ readme = "README.md"
blas = { version = "0.22.0", optional = true } blas = { version = "0.22.0", optional = true }
byteorder = "1.4.3" byteorder = "1.4.3"
candle-kernels = { path = "../candle-kernels", optional = true } candle-kernels = { path = "../candle-kernels", optional = true }
cudarc = { version = "0.9.9", optional = true, features = ["f16"] } # cudarc = { version = "0.9.12", optional = true, features = ["f16"] }
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] }
# TODO: Switch back to the official gemm implementation once something similar to # TODO: Switch back to the official gemm implementation once something similar to
# https://github.com/sarah-ek/gemm/pull/8 is available. # https://github.com/sarah-ek/gemm/pull/8 is available.
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" } gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }

View File

@ -827,8 +827,17 @@ impl CudaStorage {
let elem_count = b * m * n; let elem_count = b * m * n;
let dev = &self.device; let dev = &self.device;
let slice = match (&self.slice, &rhs.slice) { let slice = match (&self.slice, &rhs.slice) {
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => { (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
todo!("bf16") let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}?;
CudaStorageSlice::BF16(out)
} }
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..); let lhs = &lhs.slice(lhs_l.start_offset()..);