mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Matmul cublas support for f16.
This commit is contained in:
@ -788,8 +788,15 @@ impl CudaStorage {
|
|||||||
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
|
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
|
||||||
todo!("bf16")
|
todo!("bf16")
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F16(_lhs), CudaStorageSlice::F16(_rhs)) => {
|
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||||
todo!("f16")
|
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||||
|
let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||||
|
unsafe {
|
||||||
|
self.device
|
||||||
|
.blas
|
||||||
|
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||||
|
}?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
|
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||||
|
Reference in New Issue
Block a user