Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Matmul cublas support for f16.
@@ -788,8 +788,15 @@ impl CudaStorage {
             (CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
                 todo!("bf16")
             }
-            (CudaStorageSlice::F16(_lhs), CudaStorageSlice::F16(_rhs)) => {
-                todo!("f16")
+            (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
+                let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
+                let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
+                unsafe {
+                    self.device
+                        .blas
+                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
+                }?;
+                CudaStorageSlice::F16(out)
             }
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
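
The (rhs, lhs) argument order in the gemm_strided_batched call is deliberate: cuBLAS works in column-major layout, so passing the row-major operands swapped (and exchanging m and n in the config) produces the row-major product without any explicit transpose. Below is a sketch of the kind of StridedBatchedConfig that gemm_config presumably builds for a contiguous row-major f16 matmul; the cudarc types and half constants are real, but the helper name and its field values are illustrative assumptions, not lifted from the candle source.

// Illustrative only: a hypothetical stand-in for candle's gemm_config,
// assuming contiguous row-major (b, m, k) x (b, k, n) inputs.
use cudarc::cublas::{GemmConfig, StridedBatchedConfig};
use cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N;
use half::f16;

fn example_f16_cfg(b: usize, m: usize, n: usize, k: usize) -> StridedBatchedConfig<f16> {
    StridedBatchedConfig {
        // cuBLAS is column-major, so operand "a" is the row-major rhs
        // (k x n) and "b" is the row-major lhs (m x k), with m/n swapped.
        gemm: GemmConfig {
            transa: CUBLAS_OP_N,
            transb: CUBLAS_OP_N,
            m: n as i32,
            n: m as i32,
            k: k as i32,
            alpha: f16::ONE,   // plain matmul: alpha = 1
            lda: n as i32,     // leading dimension of rhs
            ldb: k as i32,     // leading dimension of lhs
            beta: f16::ZERO,   // plain matmul: beta = 0, output overwritten
            ldc: n as i32,     // leading dimension of the output
        },
        batch_size: b as i32,
        stride_a: (n * k) as i64, // elements between consecutive rhs matrices
        stride_b: (m * k) as i64, // elements between consecutive lhs matrices
        stride_c: (m * n) as i64, // elements between consecutive outputs
    }
}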
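
For context, a minimal caller-side sketch of what this commit enables, assuming the candle-core crate with the "cuda" feature; API names follow recent candle and may differ from the version this commit was written against:

// f16 matmul on CUDA now dispatches to cuBLAS instead of hitting todo!().
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;
    // Batched f16 product: (2, 4, 8) x (2, 8, 3) -> (2, 4, 3),
    // executed by gemm_strided_batched on the GPU.
    let lhs = Tensor::randn(0f32, 1.0, (2, 4, 8), &dev)?.to_dtype(DType::F16)?;
    let rhs = Tensor::randn(0f32, 1.0, (2, 8, 3), &dev)?.to_dtype(DType::F16)?;
    let out = lhs.matmul(&rhs)?;
    println!("{:?}", out.shape()); // (2, 4, 3)
    Ok(())
}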