mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
F16/BF16 bugfix (bis). (#2143)
* F16/BF16 bugfix (bis). * Another fix. * Yet another fix.
This commit is contained in:
@ -1890,14 +1890,24 @@ unsafe fn gemm_strided_batched_f16(
|
||||
use cudarc::cublas::sys;
|
||||
use cudarc::driver::DevicePtrMut;
|
||||
|
||||
let compute_type = if gemm_reduced_precision_f16() {
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
|
||||
let alpha = cfg.gemm.alpha;
|
||||
let beta = cfg.gemm.beta;
|
||||
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
||||
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
||||
let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() {
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
|
||||
(&alpha) as *const f16 as *const _,
|
||||
(&beta) as *const f16 as *const _,
|
||||
)
|
||||
} else {
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
(&alpha_f32) as *const f32 as *const _,
|
||||
(&beta_f32) as *const f32 as *const _,
|
||||
)
|
||||
};
|
||||
|
||||
let alpha: f32 = cfg.gemm.alpha.to_f32();
|
||||
let beta: f32 = cfg.gemm.beta.to_f32();
|
||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||
*cublas.handle(),
|
||||
cfg.gemm.transa,
|
||||
@ -1905,7 +1915,7 @@ unsafe fn gemm_strided_batched_f16(
|
||||
cfg.gemm.m,
|
||||
cfg.gemm.n,
|
||||
cfg.gemm.k,
|
||||
(&alpha) as *const f32 as *const _,
|
||||
alpha,
|
||||
*a.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.lda,
|
||||
@ -1914,7 +1924,7 @@ unsafe fn gemm_strided_batched_f16(
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.ldb,
|
||||
cfg.stride_b,
|
||||
(&beta) as *const f32 as *const _,
|
||||
beta,
|
||||
*c.device_ptr_mut() as *mut _,
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.ldc,
|
||||
@ -1935,14 +1945,26 @@ unsafe fn gemm_strided_batched_bf16(
|
||||
use cudarc::cublas::sys;
|
||||
use cudarc::driver::DevicePtrMut;
|
||||
|
||||
let compute_type = if gemm_reduced_precision_bf16() {
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
|
||||
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
||||
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
||||
let alpha = f16::from_f32(alpha_f32);
|
||||
let beta = f16::from_f32(beta_f32);
|
||||
// The type for alpha and beta depends on the computeType.
|
||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
|
||||
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
|
||||
(&alpha) as *const f16 as *const _,
|
||||
(&beta) as *const f16 as *const _,
|
||||
)
|
||||
} else {
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
(&alpha_f32) as *const f32 as *const _,
|
||||
(&beta_f32) as *const f32 as *const _,
|
||||
)
|
||||
};
|
||||
|
||||
let alpha: f32 = cfg.gemm.alpha.to_f32();
|
||||
let beta: f32 = cfg.gemm.beta.to_f32();
|
||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||
*cublas.handle(),
|
||||
cfg.gemm.transa,
|
||||
@ -1950,7 +1972,7 @@ unsafe fn gemm_strided_batched_bf16(
|
||||
cfg.gemm.m,
|
||||
cfg.gemm.n,
|
||||
cfg.gemm.k,
|
||||
(&alpha) as *const f32 as *const _,
|
||||
alpha,
|
||||
*a.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.lda,
|
||||
@ -1959,7 +1981,7 @@ unsafe fn gemm_strided_batched_bf16(
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.ldb,
|
||||
cfg.stride_b,
|
||||
(&beta) as *const f32 as *const _,
|
||||
beta,
|
||||
*c.device_ptr_mut() as *mut _,
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.ldc,
|
||||
|
Reference in New Issue
Block a user