mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Bugfix the recent f16/bf16 changes. (#2142)
This commit is contained in:
@ -1896,8 +1896,8 @@ unsafe fn gemm_strided_batched_f16(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
||||
};
|
||||
|
||||
- let alpha = cfg.gemm.alpha;
|
||||
- let beta = cfg.gemm.beta;
|
||||
+ let alpha: f32 = cfg.gemm.alpha.to_f32();
|
||||
+ let beta: f32 = cfg.gemm.beta.to_f32();
|
||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||
*cublas.handle(),
|
||||
cfg.gemm.transa,
|
||||
@ -1905,7 +1905,7 @@ unsafe fn gemm_strided_batched_f16(
|
||||
cfg.gemm.m,
|
||||
cfg.gemm.n,
|
||||
cfg.gemm.k,
|
||||
- (&alpha) as *const f16 as *const _,
|
||||
+ (&alpha) as *const f32 as *const _,
|
||||
*a.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.lda,
|
||||
@ -1914,7 +1914,7 @@ unsafe fn gemm_strided_batched_f16(
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.ldb,
|
||||
cfg.stride_b,
|
||||
- (&beta) as *const f16 as *const _,
|
||||
+ (&beta) as *const f32 as *const _,
|
||||
*c.device_ptr_mut() as *mut _,
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.ldc,
|
||||
@ -1941,8 +1941,8 @@ unsafe fn gemm_strided_batched_bf16(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
||||
};
|
||||
|
||||
- let alpha = cfg.gemm.alpha;
|
||||
- let beta = cfg.gemm.beta;
|
||||
+ let alpha: f32 = cfg.gemm.alpha.to_f32();
|
||||
+ let beta: f32 = cfg.gemm.beta.to_f32();
|
||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||
*cublas.handle(),
|
||||
cfg.gemm.transa,
|
||||
@ -1950,7 +1950,7 @@ unsafe fn gemm_strided_batched_bf16(
|
||||
cfg.gemm.m,
|
||||
cfg.gemm.n,
|
||||
cfg.gemm.k,
|
||||
- (&alpha) as *const bf16 as *const _,
|
||||
+ (&alpha) as *const f32 as *const _,
|
||||
*a.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.lda,
|
||||
@ -1959,7 +1959,7 @@ unsafe fn gemm_strided_batched_bf16(
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.ldb,
|
||||
cfg.stride_b,
|
||||
- (&beta) as *const bf16 as *const _,
|
||||
+ (&beta) as *const f32 as *const _,
|
||||
*c.device_ptr_mut() as *mut _,
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.ldc,
|
||||
|
Reference in New Issue
Block a user