mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Bugfix the recent f16/bf16 changes. (#2142)
This commit is contained in:
@ -1896,8 +1896,8 @@ unsafe fn gemm_strided_batched_f16(
|
|||||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
||||||
};
|
};
|
||||||
|
|
||||||
let alpha = cfg.gemm.alpha;
|
let alpha: f32 = cfg.gemm.alpha.to_f32();
|
||||||
let beta = cfg.gemm.beta;
|
let beta: f32 = cfg.gemm.beta.to_f32();
|
||||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||||
*cublas.handle(),
|
*cublas.handle(),
|
||||||
cfg.gemm.transa,
|
cfg.gemm.transa,
|
||||||
@ -1905,7 +1905,7 @@ unsafe fn gemm_strided_batched_f16(
|
|||||||
cfg.gemm.m,
|
cfg.gemm.m,
|
||||||
cfg.gemm.n,
|
cfg.gemm.n,
|
||||||
cfg.gemm.k,
|
cfg.gemm.k,
|
||||||
(&alpha) as *const f16 as *const _,
|
(&alpha) as *const f32 as *const _,
|
||||||
*a.device_ptr() as *const _,
|
*a.device_ptr() as *const _,
|
||||||
sys::cudaDataType_t::CUDA_R_16F,
|
sys::cudaDataType_t::CUDA_R_16F,
|
||||||
cfg.gemm.lda,
|
cfg.gemm.lda,
|
||||||
@ -1914,7 +1914,7 @@ unsafe fn gemm_strided_batched_f16(
|
|||||||
sys::cudaDataType_t::CUDA_R_16F,
|
sys::cudaDataType_t::CUDA_R_16F,
|
||||||
cfg.gemm.ldb,
|
cfg.gemm.ldb,
|
||||||
cfg.stride_b,
|
cfg.stride_b,
|
||||||
(&beta) as *const f16 as *const _,
|
(&beta) as *const f32 as *const _,
|
||||||
*c.device_ptr_mut() as *mut _,
|
*c.device_ptr_mut() as *mut _,
|
||||||
sys::cudaDataType_t::CUDA_R_16F,
|
sys::cudaDataType_t::CUDA_R_16F,
|
||||||
cfg.gemm.ldc,
|
cfg.gemm.ldc,
|
||||||
@ -1941,8 +1941,8 @@ unsafe fn gemm_strided_batched_bf16(
|
|||||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
||||||
};
|
};
|
||||||
|
|
||||||
let alpha = cfg.gemm.alpha;
|
let alpha: f32 = cfg.gemm.alpha.to_f32();
|
||||||
let beta = cfg.gemm.beta;
|
let beta: f32 = cfg.gemm.beta.to_f32();
|
||||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||||
*cublas.handle(),
|
*cublas.handle(),
|
||||||
cfg.gemm.transa,
|
cfg.gemm.transa,
|
||||||
@ -1950,7 +1950,7 @@ unsafe fn gemm_strided_batched_bf16(
|
|||||||
cfg.gemm.m,
|
cfg.gemm.m,
|
||||||
cfg.gemm.n,
|
cfg.gemm.n,
|
||||||
cfg.gemm.k,
|
cfg.gemm.k,
|
||||||
(&alpha) as *const bf16 as *const _,
|
(&alpha) as *const f32 as *const _,
|
||||||
*a.device_ptr() as *const _,
|
*a.device_ptr() as *const _,
|
||||||
sys::cudaDataType_t::CUDA_R_16BF,
|
sys::cudaDataType_t::CUDA_R_16BF,
|
||||||
cfg.gemm.lda,
|
cfg.gemm.lda,
|
||||||
@ -1959,7 +1959,7 @@ unsafe fn gemm_strided_batched_bf16(
|
|||||||
sys::cudaDataType_t::CUDA_R_16BF,
|
sys::cudaDataType_t::CUDA_R_16BF,
|
||||||
cfg.gemm.ldb,
|
cfg.gemm.ldb,
|
||||||
cfg.stride_b,
|
cfg.stride_b,
|
||||||
(&beta) as *const bf16 as *const _,
|
(&beta) as *const f32 as *const _,
|
||||||
*c.device_ptr_mut() as *mut _,
|
*c.device_ptr_mut() as *mut _,
|
||||||
sys::cudaDataType_t::CUDA_R_16BF,
|
sys::cudaDataType_t::CUDA_R_16BF,
|
||||||
cfg.gemm.ldc,
|
cfg.gemm.ldc,
|
||||||
|
Reference in New Issue
Block a user