diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index dbb89eaf..39b41d2e 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -1890,14 +1890,24 @@ unsafe fn gemm_strided_batched_f16(
     use cudarc::cublas::sys;
     use cudarc::driver::DevicePtrMut;
 
-    let compute_type = if gemm_reduced_precision_f16() {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
+    let alpha = cfg.gemm.alpha;
+    let beta = cfg.gemm.beta;
+    let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
+    let beta_f32: f32 = cfg.gemm.beta.to_f32();
+    let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
+            (&alpha) as *const f16 as *const _,
+            (&beta) as *const f16 as *const _,
+        )
     } else {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
     };
-    let alpha: f32 = cfg.gemm.alpha.to_f32();
-    let beta: f32 = cfg.gemm.beta.to_f32();
+
     cudarc::cublas::result::gemm_strided_batched_ex(
         *cublas.handle(),
         cfg.gemm.transa,
@@ -1905,7 +1915,7 @@ unsafe fn gemm_strided_batched_f16(
         cfg.gemm.m,
         cfg.gemm.n,
         cfg.gemm.k,
-        (&alpha) as *const f32 as *const _,
+        alpha,
         *a.device_ptr() as *const _,
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.lda,
@@ -1914,7 +1924,7 @@ unsafe fn gemm_strided_batched_f16(
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.ldb,
         cfg.stride_b,
-        (&beta) as *const f32 as *const _,
+        beta,
         *c.device_ptr_mut() as *mut _,
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.ldc,
@@ -1935,14 +1945,26 @@ unsafe fn gemm_strided_batched_bf16(
     use cudarc::cublas::sys;
     use cudarc::driver::DevicePtrMut;
 
-    let compute_type = if gemm_reduced_precision_bf16() {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
+    let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
+    let beta_f32: f32 = cfg.gemm.beta.to_f32();
+    let alpha = f16::from_f32(alpha_f32);
+    let beta = f16::from_f32(beta_f32);
+    // The type for alpha and beta depends on the computeType.
+    // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
+    let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
+            (&alpha) as *const f16 as *const _,
+            (&beta) as *const f16 as *const _,
+        )
     } else {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
     };
-    let alpha: f32 = cfg.gemm.alpha.to_f32();
-    let beta: f32 = cfg.gemm.beta.to_f32();
+
     cudarc::cublas::result::gemm_strided_batched_ex(
         *cublas.handle(),
         cfg.gemm.transa,
@@ -1950,7 +1972,7 @@ unsafe fn gemm_strided_batched_bf16(
         cfg.gemm.m,
         cfg.gemm.n,
         cfg.gemm.k,
-        (&alpha) as *const f32 as *const _,
+        alpha,
         *a.device_ptr() as *const _,
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.lda,
@@ -1959,7 +1981,7 @@ unsafe fn gemm_strided_batched_bf16(
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.ldb,
         cfg.stride_b,
-        (&beta) as *const f32 as *const _,
+        beta,
         *c.device_ptr_mut() as *mut _,
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.ldc,
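
The patch hinges on the cuBLAS rule cited in the comment: for cublasGemmStridedBatchedEx, the alpha and beta scalars are read according to the computeType, so CUBLAS_COMPUTE_16F expects half-precision scalars while CUBLAS_COMPUTE_32F expects f32 ones. The sketch below illustrates that selection pattern in isolation, without cudarc or a GPU; the pick_scalars helper, the reduced_precision flag, and the dependency on the half crate are illustrative assumptions here, not part of the patch itself.

```rust
// Minimal sketch of the alpha/beta selection pattern used in the diff:
// pick the compute-type tag together with type-erased pointers to scalars
// of the matching width. Assumes the `half` crate for the f16 type.
use std::ffi::c_void;

use half::f16;

/// Return a compute-type label plus pointers to correctly typed alpha/beta.
/// The caller must keep the referenced scalars alive while the pointers are used.
fn pick_scalars(
    reduced_precision: bool,
    alpha_f16: &f16,
    beta_f16: &f16,
    alpha_f32: &f32,
    beta_f32: &f32,
) -> (&'static str, *const c_void, *const c_void) {
    if reduced_precision {
        // CUBLAS_COMPUTE_16F: alpha/beta are interpreted as 16-bit floats.
        (
            "CUBLAS_COMPUTE_16F",
            alpha_f16 as *const f16 as *const c_void,
            beta_f16 as *const f16 as *const c_void,
        )
    } else {
        // CUBLAS_COMPUTE_32F: alpha/beta are interpreted as 32-bit floats.
        (
            "CUBLAS_COMPUTE_32F",
            alpha_f32 as *const f32 as *const c_void,
            beta_f32 as *const f32 as *const c_void,
        )
    }
}

fn main() {
    let alpha_f32 = 1.0f32;
    let beta_f32 = 0.0f32;
    let alpha_f16 = f16::from_f32(alpha_f32);
    let beta_f16 = f16::from_f32(beta_f32);

    let (compute_type, alpha_ptr, beta_ptr) =
        pick_scalars(true, &alpha_f16, &beta_f16, &alpha_f32, &beta_f32);
    println!("{compute_type}: alpha at {alpha_ptr:?}, beta at {beta_ptr:?}");
}
```

Returning type-erased pointers from a single if/else keeps one call site for the GEMM while the scalar width follows the compute type, which is the same shape the diff gives to (compute_type, alpha, beta).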