diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index 7929fba6..00e937cb 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -5,32 +5,26 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle_core::{Device, Module, Tensor};
-
-use candle_core::quantized::{QMatMul, QTensor};
+use candle_core::{Device, Tensor};
 
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
-    let q_cpu = q.to_device(&Device::Cpu)?;
-    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
-    let q = QMatMul::from_qtensor(q)?;
-    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
-    let res_q_cuda = q.forward(&x)?;
-    println!("{res_q_cuda}");
-
-    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
-    let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
-    let q_cpu = QMatMul::from_qtensor(q_cpu)?;
-    let x_cpu = x.to_device(&Device::Cpu)?;
-    let res_q_cpu = q_cpu.forward(&x_cpu)?;
-    println!("{res_q_cpu}");
-
-    let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
-    let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
-        .abs()?
-        .flatten_all()?
-        .max(0)?;
-    println!("{diff}");
+    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
+    candle_core::cuda::set_gemm_reduced_precision_f32(false);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("fp32: {:?}", start_time.elapsed());
+    drop(_x1);
+    candle_core::cuda::set_gemm_reduced_precision_f32(true);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("tf32: {:?}", start_time.elapsed());
+    drop(_x1);
     Ok(())
 }
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 0b9c7a2c..1ea9beaf 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -1615,12 +1615,8 @@ impl BackendStorage for CudaStorage {
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
                 let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
-                unsafe {
-                    self.device
-                        .blas
-                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-                }
-                .w()?;
+                unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
+                    .w()?;
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
@@ -1817,6 +1813,20 @@ static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
     std::sync::atomic::AtomicBool::new(false);
 static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
     std::sync::atomic::AtomicBool::new(false);
+static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
+    std::sync::atomic::AtomicBool::new(false);
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn gemm_reduced_precision_f32() -> bool {
+    MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn set_gemm_reduced_precision_f32(b: bool) {
+    MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
+}
 
 /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
 /// allowed with f16 GEMMs.
@@ -1842,6 +1852,51 @@ pub fn set_gemm_reduced_precision_bf16(b: bool) {
     MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
 }
 
+unsafe fn gemm_strided_batched_f32(
+    cublas: &cudarc::cublas::CudaBlas,
+    cfg: StridedBatchedConfig<f32>,
+    a: &cudarc::driver::CudaView<f32>,
+    b: &cudarc::driver::CudaView<f32>,
+    c: &mut CudaSlice<f32>,
+) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
+    use cudarc::cublas::sys;
+    use cudarc::driver::DevicePtrMut;
+
+    let compute_type = if gemm_reduced_precision_f32() {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
+    } else {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+    };
+    let alpha = &cfg.gemm.alpha as *const f32 as *const _;
+    let beta = &cfg.gemm.beta as *const f32 as *const _;
+
+    cudarc::cublas::result::gemm_strided_batched_ex(
+        *cublas.handle(),
+        cfg.gemm.transa,
+        cfg.gemm.transb,
+        cfg.gemm.m,
+        cfg.gemm.n,
+        cfg.gemm.k,
+        alpha,
+        *a.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.lda,
+        cfg.stride_a,
+        *b.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.ldb,
+        cfg.stride_b,
+        beta,
+        *c.device_ptr_mut() as *mut _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.ldc,
+        cfg.stride_c,
+        cfg.batch_size,
+        compute_type,
+        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
+    )
+}
+
 unsafe fn gemm_strided_batched_f16(
     cublas: &cudarc::cublas::CudaBlas,
     cfg: StridedBatchedConfig<f16>,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 90c2fe48..68eef1ef 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -258,3 +258,13 @@ pub fn gemm_reduced_precision_bf16() -> bool {
 /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
 /// allowed with bf16 GEMMs.
 pub fn set_gemm_reduced_precision_bf16(_: bool) {}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn gemm_reduced_precision_f32() -> bool {
+    true
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn set_gemm_reduced_precision_f32(_b: bool) {}
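
For reference, a minimal sketch (not part of the patch) of how downstream code could opt into TF32 accumulation for a single f32 matmul using the accessors added above; the `tf32_matmul` helper name and the tensor shapes are illustrative only, and the save/restore pattern is used because the flag is a process-wide atomic rather than a per-call option:

use candle_core::{Device, Result, Tensor};

// Illustrative helper: enable TF32 accumulation for one matmul, then restore
// the previous global setting read via gemm_reduced_precision_f32().
fn tf32_matmul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    let prev = candle_core::cuda::gemm_reduced_precision_f32();
    candle_core::cuda::set_gemm_reduced_precision_f32(true);
    let res = a.matmul(b);
    candle_core::cuda::set_gemm_reduced_precision_f32(prev);
    res
}

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let a = Tensor::randn(0f32, 1.0, (1024, 1024), &device)?;
    let b = Tensor::randn(0f32, 1.0, (1024, 1024), &device)?;
    let c = tf32_matmul(&a, &b)?;
    println!("{:?}", c.shape());
    Ok(())
}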