From a6a7477bea8721332b160d571894a480e7298376 Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 26 Jun 2023 22:08:22 +0100
Subject: [PATCH] Matmul cublas support for f16.

---
 src/cuda_backend.rs | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index d5be8bf6..12790125 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -788,8 +788,15 @@ impl CudaStorage {
             (CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
                 todo!("bf16")
             }
-            (CudaStorageSlice::F16(_lhs), CudaStorageSlice::F16(_rhs)) => {
-                todo!("f16")
+            (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
+                let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
+                let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
+                unsafe {
+                    self.device
+                        .blas
+                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
+                }?;
+                CudaStorageSlice::F16(out)
             }
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;