F16/BF16 bugfix (bis). (#2143)

* F16/BF16 bugfix (bis).

* Another fix.

* Yet another fix.
Author: Laurent Mazare
Date: 2024-04-29 14:08:44 +02:00
Committed by: GitHub
Parent: 09d4845aa8
Commit: fa06f5f5f9

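Why the scaling factors changed: per the cuBLAS documentation referenced in the diff below, cublasGemmStridedBatchedEx interprets the alpha/beta pointers according to the computeType, so CUBLAS_COMPUTE_16F expects pointers to half-precision values while CUBLAS_COMPUTE_32F expects pointers to f32. The previous code always passed f32 pointers, so the reduced-precision path read the first two bytes of an f32 as an f16 and got wrong scaling factors. A minimal host-side sketch (not part of the commit) that illustrates the misread; it assumes only the `half` crate, which this file already uses for `f16`:

    use half::f16;

    fn main() {
        let alpha_f32: f32 = 1.0;
        // If a pointer to this f32 is handed to a CUBLAS_COMPUTE_16F call, cuBLAS
        // reads just the first two bytes as an f16. For 1.0f32 (0x3F80_0000, stored
        // little-endian) those bytes are 0x0000, i.e. an f16 of 0.0 instead of 1.0.
        let [b0, b1, _, _] = alpha_f32.to_le_bytes();
        let misread = f16::from_bits(u16::from_le_bytes([b0, b1]));
        // The fix below converts explicitly and passes a pointer to an actual f16.
        let converted = f16::from_f32(alpha_f32);
        println!("intended = {alpha_f32}, misread as f16 = {misread}, converted = {converted}");
    }
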
@@ -1890,14 +1890,24 @@ unsafe fn gemm_strided_batched_f16(
     use cudarc::cublas::sys;
     use cudarc::driver::DevicePtrMut;
-    let compute_type = if gemm_reduced_precision_f16() {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
+    let alpha = cfg.gemm.alpha;
+    let beta = cfg.gemm.beta;
+    let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
+    let beta_f32: f32 = cfg.gemm.beta.to_f32();
+    let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
+            (&alpha) as *const f16 as *const _,
+            (&beta) as *const f16 as *const _,
+        )
     } else {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
     };
-    let alpha: f32 = cfg.gemm.alpha.to_f32();
-    let beta: f32 = cfg.gemm.beta.to_f32();
     cudarc::cublas::result::gemm_strided_batched_ex(
         *cublas.handle(),
         cfg.gemm.transa,
@@ -1905,7 +1915,7 @@ unsafe fn gemm_strided_batched_f16(
         cfg.gemm.m,
         cfg.gemm.n,
         cfg.gemm.k,
-        (&alpha) as *const f32 as *const _,
+        alpha,
         *a.device_ptr() as *const _,
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.lda,
@@ -1914,7 +1924,7 @@ unsafe fn gemm_strided_batched_f16(
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.ldb,
         cfg.stride_b,
-        (&beta) as *const f32 as *const _,
+        beta,
         *c.device_ptr_mut() as *mut _,
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.ldc,
@@ -1935,14 +1945,26 @@ unsafe fn gemm_strided_batched_bf16(
     use cudarc::cublas::sys;
     use cudarc::driver::DevicePtrMut;
-    let compute_type = if gemm_reduced_precision_bf16() {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
+    let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
+    let beta_f32: f32 = cfg.gemm.beta.to_f32();
+    let alpha = f16::from_f32(alpha_f32);
+    let beta = f16::from_f32(beta_f32);
+    // The type for alpha and beta depends on the computeType.
+    // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
+    let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
+            (&alpha) as *const f16 as *const _,
+            (&beta) as *const f16 as *const _,
+        )
     } else {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
     };
-    let alpha: f32 = cfg.gemm.alpha.to_f32();
-    let beta: f32 = cfg.gemm.beta.to_f32();
     cudarc::cublas::result::gemm_strided_batched_ex(
         *cublas.handle(),
         cfg.gemm.transa,
@@ -1950,7 +1972,7 @@ unsafe fn gemm_strided_batched_bf16(
         cfg.gemm.m,
         cfg.gemm.n,
         cfg.gemm.k,
-        (&alpha) as *const f32 as *const _,
+        alpha,
         *a.device_ptr() as *const _,
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.lda,
@@ -1959,7 +1981,7 @@ unsafe fn gemm_strided_batched_bf16(
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.ldb,
         cfg.stride_b,
-        (&beta) as *const f32 as *const _,
+        beta,
         *c.device_ptr_mut() as *mut _,
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.ldc,