From 02b5c38049cd5d2564a901db618a6eaec0a661d6 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 7 Jul 2023 08:00:12 +0100 Subject: [PATCH] Use cublas bf16. (#101) --- candle-core/Cargo.toml | 3 ++- candle-core/src/cuda_backend.rs | 13 +++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 529f9812..943f1953 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -14,7 +14,8 @@ readme = "README.md" blas = { version = "0.22.0", optional = true } byteorder = "1.4.3" candle-kernels = { path = "../candle-kernels", optional = true } -cudarc = { version = "0.9.9", optional = true, features = ["f16"] } +# cudarc = { version = "0.9.12", optional = true, features = ["f16"] } +cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", optional = true, features = ["f16"] } # TODO: Switch back to the official gemm implementation once something similar to # https://github.com/sarah-ek/gemm/pull/8 is available. 
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" } diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 917655fc..b1990b8f 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -827,8 +827,17 @@ impl CudaStorage { let elem_count = b * m * n; let dev = &self.device; let slice = match (&self.slice, &rhs.slice) { - (CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => { - todo!("bf16") + (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; + let mut out = unsafe { dev.alloc::<bf16>(elem_count) }?; + unsafe { + self.device + .blas + .gemm_strided_batched(cfg, rhs, lhs, &mut out) + }?; + CudaStorageSlice::BF16(out) } (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { let lhs = &lhs.slice(lhs_l.start_offset()..);