Expose the cudnn algo in the conv ops. (#2892)

* Set the algo.

* Expose the cudnn preferred algo for conv ops.
Laurent Mazare
2025-04-14 08:25:32 +02:00
committed by GitHub
parent fb660b8d43
commit a52b76ae82
12 changed files with 63 additions and 26 deletions

View File

@@ -6,28 +6,18 @@ extern crate intel_mkl_src;
 use anyhow::Result;
 use candle_core::{Device, Tensor};
 
+// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
-        .to_dtype(candle_core::DType::BF16)?;
-    candle_core::cuda::set_gemm_reduced_precision_f32(false);
-    candle_core::cuda::set_gemm_reduced_precision_bf16(false);
-    let _x1 = x.matmul(&x)?;
-    drop(_x1);
-    let start_time = std::time::Instant::now();
-    let _x1 = x.matmul(&x)?;
-    device.synchronize()?;
-    println!("fp32: {:?}", start_time.elapsed());
-    drop(_x1);
-    candle_core::cuda::set_gemm_reduced_precision_f32(true);
-    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
-    let _x1 = x.matmul(&x)?;
-    drop(_x1);
-    let start_time = std::time::Instant::now();
-    let _x1 = x.matmul(&x)?;
-    device.synchronize()?;
-    println!("tf32: {:?}", start_time.elapsed());
+    let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
+    let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
+    let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
     drop(_x1);
+    for _ in 0..20 {
+        let start_time = std::time::Instant::now();
+        let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
+        device.synchronize()?;
+        println!("conv1d: {:?}", start_time.elapsed());
+    }
     Ok(())
 }

View File

@@ -55,7 +55,7 @@ impl ParamsConvTranspose1D {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum CudnnFwdAlgo {
     ImplicitGemm,
     ImplicitPrecompGemm,
@@ -152,6 +152,19 @@ impl Tensor {
         stride: usize,
         dilation: usize,
         groups: usize,
+    ) -> Result<Self> {
+        self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
+    }
+
+    /// Applies a 1D convolution over the input tensor.
+    pub fn conv1d_with_algo(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        stride: usize,
+        dilation: usize,
+        groups: usize,
+        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
     ) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.dims3()?;
         let (b_size, c_in, l_in) = self.dims3()?;
@@ -175,7 +188,7 @@ impl Tensor {
             padding,
             stride,
             dilation,
-            cudnn_fwd_algo: None,
+            cudnn_fwd_algo,
         };
         if groups == 1 {
             self.conv1d_single_group(kernel, &params)
@@ -280,6 +293,18 @@ impl Tensor {
         stride: usize,
         dilation: usize,
         groups: usize,
+    ) -> Result<Self> {
+        self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
+    }
+
+    pub fn conv2d_with_algo(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        stride: usize,
+        dilation: usize,
+        groups: usize,
+        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
         let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
@@ -299,7 +324,7 @@ impl Tensor {
             padding,
             stride,
             dilation,
-            cudnn_fwd_algo: None,
+            cudnn_fwd_algo,
         };
         if groups == 1 {
             self.conv2d_single_group(kernel, &params)
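The hunks above are the whole of the new API surface: conv1d and conv2d keep their signatures and simply forward None, while conv1d_with_algo and conv2d_with_algo accept the preferred algorithm as a trailing Option<CudnnFwdAlgo>. A minimal sketch of calling the new entry point directly, assuming a CUDA build with the cudnn feature enabled (shapes are borrowed from the benchmark example above; picking ImplicitGemm is arbitrary):

use candle_core::{conv::CudnnFwdAlgo, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // Same shapes as the cuda_basics benchmark above.
    let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
    let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
    // x.conv1d(&c, 0, 4, 1, 1) is now shorthand for passing None here.
    let y = x.conv1d_with_algo(&c, 0, 4, 1, 1, Some(CudnnFwdAlgo::ImplicitGemm))?;
    println!("{:?}", y.dims());
    Ok(())
}

As the field name suggests, the hint travels in the conv params and should only be consulted by the cuDNN forward pass; other backends are expected to ignore it.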

View File

@@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl
         padding,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let conv = if bias {
         conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?

View File

@@ -92,6 +92,7 @@ impl ConvBlock {
             stride,
             groups: 1,
             dilation: 1,
+            cudnn_fwd_algo: None,
         };
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
         let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;

View File

@@ -1,6 +1,6 @@
 //! Convolution Layers.
 use crate::BatchNorm;
-use candle::{Result, Tensor};
+use candle::{conv::CudnnFwdAlgo, Result, Tensor};
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub struct Conv1dConfig {
@@ -8,6 +8,7 @@ pub struct Conv1dConfig {
     pub stride: usize,
     pub dilation: usize,
     pub groups: usize,
+    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
 }
 
 impl Default for Conv1dConfig {
@@ -17,6 +18,7 @@ impl Default for Conv1dConfig {
             stride: 1,
             dilation: 1,
             groups: 1,
+            cudnn_fwd_algo: None,
         }
     }
 }
@@ -52,12 +54,13 @@ impl Conv1d {
 
 impl crate::Module for Conv1d {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv1d(
+        let x = x.conv1d_with_algo(
             &self.weight,
             self.config.padding,
             self.config.stride,
             self.config.dilation,
             self.config.groups,
+            self.config.cudnn_fwd_algo,
         )?;
         match &self.bias {
             None => Ok(x),
@@ -147,6 +150,7 @@ pub struct Conv2dConfig {
     pub stride: usize,
     pub dilation: usize,
     pub groups: usize,
+    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
 }
 
 impl Default for Conv2dConfig {
@@ -156,6 +160,7 @@ impl Default for Conv2dConfig {
             stride: 1,
             dilation: 1,
             groups: 1,
+            cudnn_fwd_algo: None,
         }
     }
 }
@@ -211,12 +216,13 @@ impl Conv2d {
 
 impl crate::Module for Conv2d {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv2d(
+        let x = x.conv2d_with_algo(
             &self.weight,
             self.config.padding,
             self.config.stride,
             self.config.dilation,
             self.config.groups,
+            self.config.cudnn_fwd_algo,
         )?;
         match &self.bias {
             None => Ok(x),
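At the candle-nn level the algorithm rides along in the config structs, which is why every exhaustive Conv1dConfig/Conv2dConfig literal in the model files of this commit picks up a mechanical cudnn_fwd_algo: None line. A sketch of opting in through the config instead of the raw tensor op; the conv3x3 helper and its parameters are invented for illustration:

use candle::conv::CudnnFwdAlgo;
use candle_nn::{conv2d, Conv2d, Conv2dConfig, VarBuilder};

// Hypothetical helper: a 3x3 convolution that pins the cuDNN forward pass
// to implicit GEMM rather than letting the backend choose.
fn conv3x3(c_in: usize, c_out: usize, vb: VarBuilder) -> candle::Result<Conv2d> {
    let cfg = Conv2dConfig {
        padding: 1,
        cudnn_fwd_algo: Some(CudnnFwdAlgo::ImplicitGemm),
        ..Default::default()
    };
    // Conv2d::forward now routes through conv2d_with_algo with this value.
    conv2d(c_in, c_out, 3, cfg, vb)
}

Because both config structs implement Default, code that builds them via ..Default::default() keeps compiling unchanged; only exhaustive struct literals, like the ones patched in this commit, need the new field spelled out.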

View File

@@ -124,6 +124,7 @@ impl ResidualConvUnit {
             stride: 1,
             dilation: 1,
             groups: 1,
+            cudnn_fwd_algo: None,
         };
         let conv1 = conv2d(
             conf.num_features,
@@ -208,6 +209,7 @@ impl FeatureFusionBlock {
             stride: 1,
             dilation: 1,
             groups: 1,
+            cudnn_fwd_algo: None,
         };
         let output_conv = conv2d(
             conf.num_features,
@@ -258,6 +260,7 @@ impl Scratch {
             stride: 1,
             dilation: 1,
             groups: 1,
+            cudnn_fwd_algo: None,
         };
 
         let layer1_rn = conv2d_no_bias(
@@ -319,6 +322,7 @@ impl Scratch {
             stride: 1,
             dilation: 1,
             groups: 1,
+            cudnn_fwd_algo: None,
         };
         let output_conv1 = conv2d(
             conf.num_features,
@@ -425,6 +429,7 @@ impl DPTHead {
                 stride: 2,
                 dilation: 1,
                 groups: 1,
+                cudnn_fwd_algo: None,
             },
             vb.pp("resize_layers").pp("3"),
         )?),

View File

@@ -468,6 +468,7 @@ impl EncodecConv1d {
                     stride,
                     groups: 1,
                     dilation: 1,
+                    cudnn_fwd_algo: None,
                 },
                 vb.pp("conv"),
             )?,

View File

@@ -267,6 +267,7 @@ impl StreamableConv1d {
             stride,
             dilation,
             groups,
+            cudnn_fwd_algo: None,
         };
         let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
         if k_size < stride {

View File

@@ -68,6 +68,7 @@ impl ResnetBlock2D {
             padding: 1,
             groups: 1,
             dilation: 1,
+            cudnn_fwd_algo: None,
         };
         let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
         let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -83,6 +84,7 @@ impl ResnetBlock2D {
             padding: 0,
             groups: 1,
             dilation: 1,
+            cudnn_fwd_algo: None,
         };
         Some(conv2d(
             in_channels,

View File

@@ -248,12 +248,14 @@ impl AudioEncoder {
             stride: 1,
             groups: 1,
             dilation: 1,
+            cudnn_fwd_algo: None,
         };
         let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
             groups: 1,
             dilation: 1,
+            cudnn_fwd_algo: None,
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

View File

@@ -244,12 +244,14 @@ impl AudioEncoder {
             stride: 1,
             groups: 1,
             dilation: 1,
+            cudnn_fwd_algo: None,
         };
         let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
             groups: 1,
             dilation: 1,
+            cudnn_fwd_algo: None,
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

View File

@@ -98,6 +98,7 @@ impl ConvBlock {
             stride,
             groups: 1,
             dilation: 1,
+            cudnn_fwd_algo: None,
         };
         let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;