Bugfix the recent f16/bf16 changes. (#2142)

This commit is contained in:
Laurent Mazare
2024-04-29 13:30:11 +02:00
committed by GitHub
parent a0d03aded1
commit 09d4845aa8

View File

@@ -1896,8 +1896,8 @@ unsafe fn gemm_strided_batched_f16(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
-let alpha = cfg.gemm.alpha;
-let beta = cfg.gemm.beta;
+let alpha: f32 = cfg.gemm.alpha.to_f32();
+let beta: f32 = cfg.gemm.beta.to_f32();
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
@@ -1905,7 +1905,7 @@ unsafe fn gemm_strided_batched_f16(
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
-(&alpha) as *const f16 as *const _,
+(&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.lda,
@@ -1914,7 +1914,7 @@ unsafe fn gemm_strided_batched_f16(
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldb,
cfg.stride_b,
-(&beta) as *const f16 as *const _,
+(&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldc,
@@ -1941,8 +1941,8 @@ unsafe fn gemm_strided_batched_bf16(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
-let alpha = cfg.gemm.alpha;
-let beta = cfg.gemm.beta;
+let alpha: f32 = cfg.gemm.alpha.to_f32();
+let beta: f32 = cfg.gemm.beta.to_f32();
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
@@ -1950,7 +1950,7 @@ unsafe fn gemm_strided_batched_bf16(
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
-(&alpha) as *const bf16 as *const _,
+(&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.lda,
@@ -1959,7 +1959,7 @@ unsafe fn gemm_strided_batched_bf16(
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldb,
cfg.stride_b,
-(&beta) as *const bf16 as *const _,
+(&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldc,