mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
F16/BF16 bugfix (bis). (#2143)
* F16/BF16 bugfix (bis). * Another fix. * Yet another fix.
This commit is contained in:
@ -1890,14 +1890,24 @@ unsafe fn gemm_strided_batched_f16(
|
|||||||
use cudarc::cublas::sys;
|
use cudarc::cublas::sys;
|
||||||
use cudarc::driver::DevicePtrMut;
|
use cudarc::driver::DevicePtrMut;
|
||||||
|
|
||||||
let compute_type = if gemm_reduced_precision_f16() {
|
let alpha = cfg.gemm.alpha;
|
||||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
|
let beta = cfg.gemm.beta;
|
||||||
|
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
||||||
|
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
||||||
|
let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() {
|
||||||
|
(
|
||||||
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
|
||||||
|
(&alpha) as *const f16 as *const _,
|
||||||
|
(&beta) as *const f16 as *const _,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
(
|
||||||
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||||
|
(&alpha_f32) as *const f32 as *const _,
|
||||||
|
(&beta_f32) as *const f32 as *const _,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let alpha: f32 = cfg.gemm.alpha.to_f32();
|
|
||||||
let beta: f32 = cfg.gemm.beta.to_f32();
|
|
||||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||||
*cublas.handle(),
|
*cublas.handle(),
|
||||||
cfg.gemm.transa,
|
cfg.gemm.transa,
|
||||||
@ -1905,7 +1915,7 @@ unsafe fn gemm_strided_batched_f16(
|
|||||||
cfg.gemm.m,
|
cfg.gemm.m,
|
||||||
cfg.gemm.n,
|
cfg.gemm.n,
|
||||||
cfg.gemm.k,
|
cfg.gemm.k,
|
||||||
(&alpha) as *const f32 as *const _,
|
alpha,
|
||||||
*a.device_ptr() as *const _,
|
*a.device_ptr() as *const _,
|
||||||
sys::cudaDataType_t::CUDA_R_16F,
|
sys::cudaDataType_t::CUDA_R_16F,
|
||||||
cfg.gemm.lda,
|
cfg.gemm.lda,
|
||||||
@ -1914,7 +1924,7 @@ unsafe fn gemm_strided_batched_f16(
|
|||||||
sys::cudaDataType_t::CUDA_R_16F,
|
sys::cudaDataType_t::CUDA_R_16F,
|
||||||
cfg.gemm.ldb,
|
cfg.gemm.ldb,
|
||||||
cfg.stride_b,
|
cfg.stride_b,
|
||||||
(&beta) as *const f32 as *const _,
|
beta,
|
||||||
*c.device_ptr_mut() as *mut _,
|
*c.device_ptr_mut() as *mut _,
|
||||||
sys::cudaDataType_t::CUDA_R_16F,
|
sys::cudaDataType_t::CUDA_R_16F,
|
||||||
cfg.gemm.ldc,
|
cfg.gemm.ldc,
|
||||||
@ -1935,14 +1945,26 @@ unsafe fn gemm_strided_batched_bf16(
|
|||||||
use cudarc::cublas::sys;
|
use cudarc::cublas::sys;
|
||||||
use cudarc::driver::DevicePtrMut;
|
use cudarc::driver::DevicePtrMut;
|
||||||
|
|
||||||
let compute_type = if gemm_reduced_precision_bf16() {
|
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
||||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
|
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
||||||
|
let alpha = f16::from_f32(alpha_f32);
|
||||||
|
let beta = f16::from_f32(beta_f32);
|
||||||
|
// The type for alpha and beta depends on the computeType.
|
||||||
|
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
|
||||||
|
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
|
||||||
|
(
|
||||||
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
|
||||||
|
(&alpha) as *const f16 as *const _,
|
||||||
|
(&beta) as *const f16 as *const _,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
(
|
||||||
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||||
|
(&alpha_f32) as *const f32 as *const _,
|
||||||
|
(&beta_f32) as *const f32 as *const _,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let alpha: f32 = cfg.gemm.alpha.to_f32();
|
|
||||||
let beta: f32 = cfg.gemm.beta.to_f32();
|
|
||||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||||
*cublas.handle(),
|
*cublas.handle(),
|
||||||
cfg.gemm.transa,
|
cfg.gemm.transa,
|
||||||
@ -1950,7 +1972,7 @@ unsafe fn gemm_strided_batched_bf16(
|
|||||||
cfg.gemm.m,
|
cfg.gemm.m,
|
||||||
cfg.gemm.n,
|
cfg.gemm.n,
|
||||||
cfg.gemm.k,
|
cfg.gemm.k,
|
||||||
(&alpha) as *const f32 as *const _,
|
alpha,
|
||||||
*a.device_ptr() as *const _,
|
*a.device_ptr() as *const _,
|
||||||
sys::cudaDataType_t::CUDA_R_16BF,
|
sys::cudaDataType_t::CUDA_R_16BF,
|
||||||
cfg.gemm.lda,
|
cfg.gemm.lda,
|
||||||
@ -1959,7 +1981,7 @@ unsafe fn gemm_strided_batched_bf16(
|
|||||||
sys::cudaDataType_t::CUDA_R_16BF,
|
sys::cudaDataType_t::CUDA_R_16BF,
|
||||||
cfg.gemm.ldb,
|
cfg.gemm.ldb,
|
||||||
cfg.stride_b,
|
cfg.stride_b,
|
||||||
(&beta) as *const f32 as *const _,
|
beta,
|
||||||
*c.device_ptr_mut() as *mut _,
|
*c.device_ptr_mut() as *mut _,
|
||||||
sys::cudaDataType_t::CUDA_R_16BF,
|
sys::cudaDataType_t::CUDA_R_16BF,
|
||||||
cfg.gemm.ldc,
|
cfg.gemm.ldc,
|
||||||
|
Reference in New Issue
Block a user