Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Make it possible to use TF32 accumulation in F32 matmuls. (#2178)
* Allow the use of tf32 accumulation in matmul.
* Better timings.
* Dummy versions for use when cuda is not enabled.
@@ -1615,12 +1615,8 @@ impl BackendStorage for CudaStorage {
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
                 let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
-                unsafe {
-                    self.device
-                        .blas
-                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-                }
-                .w()?;
+                unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
+                    .w()?;
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
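With this change the f32 matmul path no longer calls cudarc's generic `CudaBlas::gemm_strided_batched`, which fixes the compute type, but goes through the local `gemm_strided_batched_f32` wrapper added below, which picks the cuBLAS compute type at call time. TF32 keeps the f32 exponent range but truncates the mantissa to 10 bits, so Ampere-class tensor cores can run f32 GEMMs substantially faster at slightly reduced precision. A hypothetical end-user sketch follows; the `candle_core::cuda` re-export path is an assumption, only the setter's name comes from this diff:

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Allow reduced precision (TF32) accumulation in f32 matmuls.
        // Module path assumed; the function name comes from this commit.
        candle_core::cuda::set_gemm_reduced_precision_f32(true);

        let dev = Device::new_cuda(0)?;
        let a = Tensor::randn(0f32, 1.0, (1024, 1024), &dev)?;
        let b = Tensor::randn(0f32, 1.0, (1024, 1024), &dev)?;
        // This matmul can now run with CUBLAS_COMPUTE_32F_FAST_TF32.
        let c = a.matmul(&b)?;
        println!("{:?}", c.dims());
        Ok(())
    }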
@@ -1817,6 +1813,20 @@ static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
     std::sync::atomic::AtomicBool::new(false);
 static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
     std::sync::atomic::AtomicBool::new(false);
+static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
+    std::sync::atomic::AtomicBool::new(false);
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn gemm_reduced_precision_f32() -> bool {
+    MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn set_gemm_reduced_precision_f32(b: bool) {
+    MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
+}
 
 /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
 /// allowed with f16 GEMMs.
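The new flag mirrors the existing f16/bf16 ones: a process-wide `AtomicBool` accessed with `Ordering::Relaxed`, which is sufficient for an independent on/off tunable that has no other state to synchronize with. A hypothetical convenience wrapper, not part of this commit, that enables TF32 for a scope and restores the previous value afterwards (assuming the getter/setter are reachable via `candle_core::cuda`):

    // Hypothetical RAII guard (not part of this commit): enables TF32
    // accumulation for a scope and restores the prior setting on drop.
    struct Tf32Guard {
        prev: bool,
    }

    impl Tf32Guard {
        fn enable() -> Self {
            let prev = candle_core::cuda::gemm_reduced_precision_f32();
            candle_core::cuda::set_gemm_reduced_precision_f32(true);
            Self { prev }
        }
    }

    impl Drop for Tf32Guard {
        fn drop(&mut self) {
            candle_core::cuda::set_gemm_reduced_precision_f32(self.prev);
        }
    }

Because the flag is global, such a guard is only meaningful when all matmuls are issued from a single thread.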
@@ -1842,6 +1852,51 @@ pub fn set_gemm_reduced_precision_bf16(b: bool) {
     MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
 }
 
+unsafe fn gemm_strided_batched_f32(
+    cublas: &cudarc::cublas::CudaBlas,
+    cfg: StridedBatchedConfig<f32>,
+    a: &cudarc::driver::CudaView<f32>,
+    b: &cudarc::driver::CudaView<f32>,
+    c: &mut CudaSlice<f32>,
+) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
+    use cudarc::cublas::sys;
+    use cudarc::driver::DevicePtrMut;
+
+    let compute_type = if gemm_reduced_precision_f32() {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
+    } else {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+    };
+    let alpha = &cfg.gemm.alpha as *const f32 as *const _;
+    let beta = &cfg.gemm.beta as *const f32 as *const _;
+
+    cudarc::cublas::result::gemm_strided_batched_ex(
+        *cublas.handle(),
+        cfg.gemm.transa,
+        cfg.gemm.transb,
+        cfg.gemm.m,
+        cfg.gemm.n,
+        cfg.gemm.k,
+        alpha,
+        *a.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.lda,
+        cfg.stride_a,
+        *b.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.ldb,
+        cfg.stride_b,
+        beta,
+        *c.device_ptr_mut() as *mut _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.ldc,
+        cfg.stride_c,
+        cfg.batch_size,
+        compute_type,
+        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
+    )
+}
+
 unsafe fn gemm_strided_batched_f16(
     cublas: &cudarc::cublas::CudaBlas,
     cfg: StridedBatchedConfig<f16>,
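The commit message also mentions dummy versions for builds where the cuda feature is disabled. A minimal sketch of what such no-op stubs could look like; the actual bodies, and in particular the getter's return value, are not shown in this diff:

    // Assumed no-op stubs for non-CUDA builds; the return value is a guess.
    pub fn gemm_reduced_precision_f32() -> bool {
        true
    }
    pub fn set_gemm_reduced_precision_f32(_b: bool) {}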