Add cuBLAS matmul support for f16.

laurent
2023-06-26 22:08:22 +01:00
parent 36a4749e95
commit a6a7477bea


@@ -788,8 +788,15 @@ impl CudaStorage {
             (CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
                 todo!("bf16")
             }
-            (CudaStorageSlice::F16(_lhs), CudaStorageSlice::F16(_rhs)) => {
-                todo!("f16")
+            (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
+                let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
+                let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
+                unsafe {
+                    self.device
+                        .blas
+                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
+                }?;
+                CudaStorageSlice::F16(out)
+            }
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
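
For context on why the call above passes rhs before lhs: cuBLAS is column-major, so a row-major lhs * rhs is computed as out^T = rhs^T * lhs^T. Below is a minimal standalone sketch of the same trick using cudarc's strided-batched GEMM directly in f16; it does not use candle's gemm_config helper, assumes cudarc is built with its f16 feature, and all names in it are local to the example.

// Sketch: row-major (m, k) x (k, n) f16 matmul via cuBLAS's column-major
// gemm_strided_batched, with the operands swapped. Not candle's code;
// an illustration of the operand-swap trick under the assumptions above.
use cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N;
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::CudaDevice;
use half::f16;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let (b, m, n, k) = (1i32, 2i32, 2i32, 2i32);
    let dev = CudaDevice::new(0)?;
    let blas = CudaBlas::new(dev.clone())?;

    // Row-major lhs (m x k) and rhs (k x n) on the device.
    let lhs = dev.htod_copy(vec![f16::from_f32(1.0); (m * k) as usize])?;
    let rhs = dev.htod_copy(vec![f16::from_f32(2.0); (k * n) as usize])?;
    let mut out = dev.alloc_zeros::<f16>((m * n) as usize)?;

    // out^T = rhs^T * lhs^T: m and n swap roles, rhs becomes operand A.
    let cfg = StridedBatchedConfig {
        gemm: GemmConfig {
            transa: CUBLAS_OP_N,
            transb: CUBLAS_OP_N,
            m: n,   // columns of the row-major output
            n: m,   // rows of the row-major output
            k,
            alpha: f16::ONE,
            lda: n, // leading dimension of row-major rhs
            ldb: k, // leading dimension of row-major lhs
            beta: f16::ZERO,
            ldc: n,
        },
        batch_size: b,
        stride_a: (k * n) as i64,
        stride_b: (m * k) as i64,
        stride_c: (m * n) as i64,
    };
    unsafe { blas.gemm_strided_batched(cfg, &rhs, &lhs, &mut out) }?;

    let host = dev.sync_reclaim(out)?;
    println!("{host:?}"); // each entry is 1*2 + 1*2 = 4
    Ok(())
}

Reading the column-major result (n x m, ldc = n) back as row-major yields the m x n product lhs * rhs without any explicit transpose, which is why the diff needs no transposition kernels.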