Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Make it possible to use TF32 accumulation in F32 matmuls. (#2178)
* Allow the use of tf32 accumulation in matmul.
* Better timings.
* Dummy versions for use when cuda is not enabled.
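The toggle this PR adds is a process-wide flag read on every f32 CUDA matmul, so callers can switch TF32 accumulation on or off at runtime. Below is a minimal sketch of how downstream code might use it; the tensor shape and variable names are illustrative only, not part of the diff. TF32 keeps f32's exponent range but rounds the significand to 10 bits, which is why the two results are compared at the end.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let a = Tensor::randn(0f32, 1.0, (4096, 4096), &device)?;

    // Full f32 accumulation (CUBLAS_COMPUTE_32F), the default.
    candle_core::cuda::set_gemm_reduced_precision_f32(false);
    let exact = a.matmul(&a)?;

    // TF32 accumulation (CUBLAS_COMPUTE_32F_FAST_TF32): faster on Ampere and
    // later GPUs, with the inputs rounded to a 10-bit significand.
    candle_core::cuda::set_gemm_reduced_precision_f32(true);
    let fast = a.matmul(&a)?;

    // Largest absolute difference between the two results.
    let max_diff = (exact - fast)?.abs()?.flatten_all()?.max(0)?;
    println!("max abs diff: {max_diff}");
    Ok(())
}

On builds without the cuda feature, the dummy versions added at the end of this diff keep the same code compiling: the setter is a no-op and the getter reports reduced precision as enabled.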
@@ -5,32 +5,26 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle_core::{Device, Module, Tensor};
-
-use candle_core::quantized::{QMatMul, QTensor};
+use candle_core::{Device, Tensor};
 
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
-    let q_cpu = q.to_device(&Device::Cpu)?;
-    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
-    let q = QMatMul::from_qtensor(q)?;
-    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
-    let res_q_cuda = q.forward(&x)?;
-    println!("{res_q_cuda}");
-
-    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
-    let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
-    let q_cpu = QMatMul::from_qtensor(q_cpu)?;
-    let x_cpu = x.to_device(&Device::Cpu)?;
-    let res_q_cpu = q_cpu.forward(&x_cpu)?;
-    println!("{res_q_cpu}");
-
-    let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
-    let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
-        .abs()?
-        .flatten_all()?
-        .max(0)?;
-    println!("{diff}");
+    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
+    candle_core::cuda::set_gemm_reduced_precision_f32(false);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("fp32: {:?}", start_time.elapsed());
+    drop(_x1);
+    candle_core::cuda::set_gemm_reduced_precision_f32(true);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("tf32: {:?}", start_time.elapsed());
+    drop(_x1);
     Ok(())
 }
@@ -1615,12 +1615,8 @@ impl BackendStorage for CudaStorage {
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
                 let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
-                unsafe {
-                    self.device
-                        .blas
-                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-                }
-                .w()?;
+                unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
+                    .w()?;
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
@@ -1817,6 +1813,20 @@ static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
     std::sync::atomic::AtomicBool::new(false);
 static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
     std::sync::atomic::AtomicBool::new(false);
+static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
+    std::sync::atomic::AtomicBool::new(false);
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn gemm_reduced_precision_f32() -> bool {
+    MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn set_gemm_reduced_precision_f32(b: bool) {
+    MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
+}
 
 /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
 /// allowed with f16 GEMMs.
@@ -1842,6 +1852,51 @@ pub fn set_gemm_reduced_precision_bf16(b: bool) {
     MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
 }
 
+unsafe fn gemm_strided_batched_f32(
+    cublas: &cudarc::cublas::CudaBlas,
+    cfg: StridedBatchedConfig<f32>,
+    a: &cudarc::driver::CudaView<f32>,
+    b: &cudarc::driver::CudaView<f32>,
+    c: &mut CudaSlice<f32>,
+) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
+    use cudarc::cublas::sys;
+    use cudarc::driver::DevicePtrMut;
+
+    let compute_type = if gemm_reduced_precision_f32() {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
+    } else {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+    };
+    let alpha = &cfg.gemm.alpha as *const f32 as *const _;
+    let beta = &cfg.gemm.beta as *const f32 as *const _;
+
+    cudarc::cublas::result::gemm_strided_batched_ex(
+        *cublas.handle(),
+        cfg.gemm.transa,
+        cfg.gemm.transb,
+        cfg.gemm.m,
+        cfg.gemm.n,
+        cfg.gemm.k,
+        alpha,
+        *a.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.lda,
+        cfg.stride_a,
+        *b.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.ldb,
+        cfg.stride_b,
+        beta,
+        *c.device_ptr_mut() as *mut _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.ldc,
+        cfg.stride_c,
+        cfg.batch_size,
+        compute_type,
+        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
+    )
+}
+
 unsafe fn gemm_strided_batched_f16(
     cublas: &cudarc::cublas::CudaBlas,
     cfg: StridedBatchedConfig<f16>,
@@ -258,3 +258,13 @@ pub fn gemm_reduced_precision_bf16() -> bool {
 /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
 /// allowed with bf16 GEMMs.
 pub fn set_gemm_reduced_precision_bf16(_: bool) {}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn gemm_reduced_precision_f32() -> bool {
+    true
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn set_gemm_reduced_precision_f32(_b: bool) {}