From a52b76ae82301200d73c331af8e878855f939019 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 14 Apr 2025 08:25:32 +0200 Subject: [PATCH] Expose the cudnn algo in the conv ops. (#2892) * Set the algo. * Expose the cudnn preferred algo for conv ops. --- candle-core/examples/cuda_basics.rs | 30 ++++++------------ candle-core/src/conv.rs | 31 +++++++++++++++++-- candle-examples/examples/yolo-v3/darknet.rs | 1 + candle-examples/examples/yolo-v8/model.rs | 1 + candle-nn/src/conv.rs | 12 +++++-- .../src/models/depth_anything_v2.rs | 5 +++ candle-transformers/src/models/encodec.rs | 1 + candle-transformers/src/models/mimi/conv.rs | 1 + .../src/models/stable_diffusion/resnet.rs | 2 ++ .../src/models/whisper/model.rs | 2 ++ .../src/models/whisper/quantized_model.rs | 2 ++ candle-wasm-examples/yolo/src/model.rs | 1 + 12 files changed, 63 insertions(+), 26 deletions(-) diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 9af1b006..4eadcdeb 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -6,28 +6,18 @@ extern crate intel_mkl_src; use anyhow::Result; use candle_core::{Device, Tensor}; - +// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 } fn main() -> Result<()> { let device = Device::new_cuda(0)?; - let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)? 
- .to_dtype(candle_core::DType::BF16)?; - candle_core::cuda::set_gemm_reduced_precision_f32(false); - candle_core::cuda::set_gemm_reduced_precision_bf16(false); - let _x1 = x.matmul(&x)?; - drop(_x1); - let start_time = std::time::Instant::now(); - let _x1 = x.matmul(&x)?; - device.synchronize()?; - println!("fp32: {:?}", start_time.elapsed()); - drop(_x1); - candle_core::cuda::set_gemm_reduced_precision_f32(true); - candle_core::cuda::set_gemm_reduced_precision_bf16(true); - let _x1 = x.matmul(&x)?; - drop(_x1); - let start_time = std::time::Instant::now(); - let _x1 = x.matmul(&x)?; - device.synchronize()?; - println!("tf32: {:?}", start_time.elapsed()); + let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?; + let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?; + let _x1 = x.conv1d(&c, 0, 4, 1, 1)?; drop(_x1); + for _ in 0..20 { + let start_time = std::time::Instant::now(); + let _x1 = x.conv1d(&c, 0, 4, 1, 1)?; + device.synchronize()?; + println!("conv1d: {:?}", start_time.elapsed()); + } Ok(()) } diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 3ec7daa4..115035ef 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -55,7 +55,7 @@ impl ParamsConvTranspose1D { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum CudnnFwdAlgo { ImplicitGemm, ImplicitPrecompGemm, @@ -152,6 +152,19 @@ impl Tensor { stride: usize, dilation: usize, groups: usize, + ) -> Result<Self> { + self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None) + } + + /// Applies a 1D convolution over the input tensor. 
+ pub fn conv1d_with_algo( + &self, + kernel: &Self, + padding: usize, + stride: usize, + dilation: usize, + groups: usize, + cudnn_fwd_algo: Option<CudnnFwdAlgo>, ) -> Result<Self> { let (c_out, c_in_k, k_size) = kernel.dims3()?; let (b_size, c_in, l_in) = self.dims3()?; @@ -175,7 +188,7 @@ impl Tensor { padding, stride, dilation, - cudnn_fwd_algo: None, + cudnn_fwd_algo, }; if groups == 1 { self.conv1d_single_group(kernel, &params) @@ -280,6 +293,18 @@ impl Tensor { stride: usize, dilation: usize, groups: usize, + ) -> Result<Self> { + self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None) + } + + pub fn conv2d_with_algo( + &self, + kernel: &Self, + padding: usize, + stride: usize, + dilation: usize, + groups: usize, + cudnn_fwd_algo: Option<CudnnFwdAlgo>, ) -> Result<Self> { let (b_size, c_in, i_h, i_w) = self.dims4()?; let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; @@ -299,7 +324,7 @@ impl Tensor { padding, stride, dilation, - cudnn_fwd_algo: None, + cudnn_fwd_algo, }; if groups == 1 { self.conv2d_single_group(kernel, &params) diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs index 944f4dcb..a33087c5 100644 --- a/candle-examples/examples/yolo-v3/darknet.rs +++ b/candle-examples/examples/yolo-v3/darknet.rs @@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl) padding, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; let conv = if bias { conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))? 
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs index e1be1f3c..dc13bb97 100644 --- a/candle-examples/examples/yolo-v8/model.rs +++ b/candle-examples/examples/yolo-v8/model.rs @@ -92,6 +92,7 @@ impl ConvBlock { stride, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?; diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index c183e6b9..6b01c2c6 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -1,6 +1,6 @@ //! Convolution Layers. use crate::BatchNorm; -use candle::{Result, Tensor}; +use candle::{conv::CudnnFwdAlgo, Result, Tensor}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Conv1dConfig { @@ -8,6 +8,7 @@ pub struct Conv1dConfig { pub stride: usize, pub dilation: usize, pub groups: usize, + pub cudnn_fwd_algo: Option<CudnnFwdAlgo>, } impl Default for Conv1dConfig { @@ -17,6 +18,7 @@ impl Default for Conv1dConfig { stride: 1, dilation: 1, groups: 1, + cudnn_fwd_algo: None, } } } @@ -52,12 +54,13 @@ impl Conv1d { impl crate::Module for Conv1d { fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = x.conv1d( + let x = x.conv1d_with_algo( &self.weight, self.config.padding, self.config.stride, self.config.dilation, self.config.groups, + self.config.cudnn_fwd_algo, )?; match &self.bias { None => Ok(x), @@ -147,6 +150,7 @@ pub struct Conv2dConfig { pub stride: usize, pub dilation: usize, pub groups: usize, + pub cudnn_fwd_algo: Option<CudnnFwdAlgo>, } impl Default for Conv2dConfig { @@ -156,6 +160,7 @@ impl Default for Conv2dConfig { stride: 1, dilation: 1, groups: 1, + cudnn_fwd_algo: None, } } } @@ -211,12 +216,13 @@ impl Conv2d { impl crate::Module for Conv2d { fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = x.conv2d( + let x = x.conv2d_with_algo( &self.weight, self.config.padding, self.config.stride, self.config.dilation, self.config.groups, + self.config.cudnn_fwd_algo, )?; match &self.bias 
{ None => Ok(x), diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs index 3b6bd1a5..690d396b 100644 --- a/candle-transformers/src/models/depth_anything_v2.rs +++ b/candle-transformers/src/models/depth_anything_v2.rs @@ -124,6 +124,7 @@ impl ResidualConvUnit { stride: 1, dilation: 1, groups: 1, + cudnn_fwd_algo: None, }; let conv1 = conv2d( conf.num_features, @@ -208,6 +209,7 @@ impl FeatureFusionBlock { stride: 1, dilation: 1, groups: 1, + cudnn_fwd_algo: None, }; let output_conv = conv2d( conf.num_features, @@ -258,6 +260,7 @@ impl Scratch { stride: 1, dilation: 1, groups: 1, + cudnn_fwd_algo: None, }; let layer1_rn = conv2d_no_bias( @@ -319,6 +322,7 @@ impl Scratch { stride: 1, dilation: 1, groups: 1, + cudnn_fwd_algo: None, }; let output_conv1 = conv2d( conf.num_features, @@ -425,6 +429,7 @@ impl DPTHead { stride: 2, dilation: 1, groups: 1, + cudnn_fwd_algo: None, }, vb.pp("resize_layers").pp("3"), )?), diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index 7ed1fcec..4bea97b9 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -468,6 +468,7 @@ impl EncodecConv1d { stride, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }, vb.pp("conv"), )?, diff --git a/candle-transformers/src/models/mimi/conv.rs b/candle-transformers/src/models/mimi/conv.rs index 87e9fb4c..695c0de6 100644 --- a/candle-transformers/src/models/mimi/conv.rs +++ b/candle-transformers/src/models/mimi/conv.rs @@ -267,6 +267,7 @@ impl StreamableConv1d { stride, dilation, groups, + cudnn_fwd_algo: None, }; let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?; if k_size < stride { diff --git a/candle-transformers/src/models/stable_diffusion/resnet.rs b/candle-transformers/src/models/stable_diffusion/resnet.rs index 5cca7edd..8a6490c5 100644 --- a/candle-transformers/src/models/stable_diffusion/resnet.rs +++ 
b/candle-transformers/src/models/stable_diffusion/resnet.rs @@ -68,6 +68,7 @@ impl ResnetBlock2D { padding: 1, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?; let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?; @@ -83,6 +84,7 @@ impl ResnetBlock2D { padding: 0, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; Some(conv2d( in_channels, diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index dc50e0db..2f34b180 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -248,12 +248,14 @@ impl AudioEncoder { stride: 1, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; let cfg2 = Conv1dConfig { padding: 1, stride: 2, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs index 2db363c6..15130fbd 100644 --- a/candle-transformers/src/models/whisper/quantized_model.rs +++ b/candle-transformers/src/models/whisper/quantized_model.rs @@ -244,12 +244,14 @@ impl AudioEncoder { stride: 1, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; let cfg2 = Conv1dConfig { padding: 1, stride: 2, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs index ee98c125..c52dcc80 100644 --- a/candle-wasm-examples/yolo/src/model.rs +++ b/candle-wasm-examples/yolo/src/model.rs @@ -98,6 +98,7 @@ impl ConvBlock { stride, groups: 1, dilation: 1, + cudnn_fwd_algo: None, }; let conv = 
conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;