From 09d4845aa842dc5d9da650fd7865c4f0855dcf97 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 29 Apr 2024 13:30:11 +0200 Subject: [PATCH] Bugfix the recent f16/bf16 changes. (#2142) --- candle-core/src/cuda_backend/mod.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 67ed56e0..dbb89eaf 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1896,8 +1896,8 @@ unsafe fn gemm_strided_batched_f16( sys::cublasComputeType_t::CUBLAS_COMPUTE_32F }; - let alpha = cfg.gemm.alpha; - let beta = cfg.gemm.beta; + let alpha: f32 = cfg.gemm.alpha.to_f32(); + let beta: f32 = cfg.gemm.beta.to_f32(); cudarc::cublas::result::gemm_strided_batched_ex( *cublas.handle(), cfg.gemm.transa, @@ -1905,7 +1905,7 @@ unsafe fn gemm_strided_batched_f16( cfg.gemm.m, cfg.gemm.n, cfg.gemm.k, - (&alpha) as *const f16 as *const _, + (&alpha) as *const f32 as *const _, *a.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.lda, @@ -1914,7 +1914,7 @@ unsafe fn gemm_strided_batched_f16( sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.ldb, cfg.stride_b, - (&beta) as *const f16 as *const _, + (&beta) as *const f32 as *const _, *c.device_ptr_mut() as *mut _, sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.ldc, @@ -1941,8 +1941,8 @@ unsafe fn gemm_strided_batched_bf16( sys::cublasComputeType_t::CUBLAS_COMPUTE_32F }; - let alpha = cfg.gemm.alpha; - let beta = cfg.gemm.beta; + let alpha: f32 = cfg.gemm.alpha.to_f32(); + let beta: f32 = cfg.gemm.beta.to_f32(); cudarc::cublas::result::gemm_strided_batched_ex( *cublas.handle(), cfg.gemm.transa, @@ -1950,7 +1950,7 @@ unsafe fn gemm_strided_batched_bf16( cfg.gemm.m, cfg.gemm.n, cfg.gemm.k, - (&alpha) as *const bf16 as *const _, + (&alpha) as *const f32 as *const _, *a.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.lda, @@ -1959,7 +1959,7 @@ unsafe fn gemm_strided_batched_bf16( sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.ldb, cfg.stride_b, - (&beta) as *const bf16 as *const _, + (&beta) as *const f32 as *const _, *c.device_ptr_mut() as *mut _, sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.ldc,