mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00
Expose the cudnn algo in the conv ops. (#2892)
* Set the algo. * Expose the cudnn preferred algo for conv ops.
This commit is contained in:
@ -6,28 +6,18 @@ extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("fp32: {:?}", start_time.elapsed());
|
||||
drop(_x1);
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("tf32: {:?}", start_time.elapsed());
|
||||
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
|
||||
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
|
||||
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
||||
drop(_x1);
|
||||
for _ in 0..20 {
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
||||
device.synchronize()?;
|
||||
println!("conv1d: {:?}", start_time.elapsed());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ impl ParamsConvTranspose1D {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum CudnnFwdAlgo {
|
||||
ImplicitGemm,
|
||||
ImplicitPrecompGemm,
|
||||
@ -152,6 +152,19 @@ impl Tensor {
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d_with_algo(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
@ -175,7 +188,7 @@ impl Tensor {
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
cudnn_fwd_algo: None,
|
||||
cudnn_fwd_algo,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv1d_single_group(kernel, ¶ms)
|
||||
@ -280,6 +293,18 @@ impl Tensor {
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
|
||||
}
|
||||
|
||||
pub fn conv2d_with_algo(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||
@ -299,7 +324,7 @@ impl Tensor {
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
cudnn_fwd_algo: None,
|
||||
cudnn_fwd_algo,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv2d_single_group(kernel, ¶ms)
|
||||
|
@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
||||
padding,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv = if bias {
|
||||
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
||||
|
@ -92,6 +92,7 @@ impl ConvBlock {
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
|
||||
|
@ -1,6 +1,6 @@
|
||||
//! Convolution Layers.
|
||||
use crate::BatchNorm;
|
||||
use candle::{Result, Tensor};
|
||||
use candle::{conv::CudnnFwdAlgo, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Conv1dConfig {
|
||||
@ -8,6 +8,7 @@ pub struct Conv1dConfig {
|
||||
pub stride: usize,
|
||||
pub dilation: usize,
|
||||
pub groups: usize,
|
||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl Default for Conv1dConfig {
|
||||
@ -17,6 +18,7 @@ impl Default for Conv1dConfig {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -52,12 +54,13 @@ impl Conv1d {
|
||||
|
||||
impl crate::Module for Conv1d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv1d(
|
||||
let x = x.conv1d_with_algo(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.stride,
|
||||
self.config.dilation,
|
||||
self.config.groups,
|
||||
self.config.cudnn_fwd_algo,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
@ -147,6 +150,7 @@ pub struct Conv2dConfig {
|
||||
pub stride: usize,
|
||||
pub dilation: usize,
|
||||
pub groups: usize,
|
||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl Default for Conv2dConfig {
|
||||
@ -156,6 +160,7 @@ impl Default for Conv2dConfig {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -211,12 +216,13 @@ impl Conv2d {
|
||||
|
||||
impl crate::Module for Conv2d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv2d(
|
||||
let x = x.conv2d_with_algo(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.stride,
|
||||
self.config.dilation,
|
||||
self.config.groups,
|
||||
self.config.cudnn_fwd_algo,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
|
@ -124,6 +124,7 @@ impl ResidualConvUnit {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv1 = conv2d(
|
||||
conf.num_features,
|
||||
@ -208,6 +209,7 @@ impl FeatureFusionBlock {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let output_conv = conv2d(
|
||||
conf.num_features,
|
||||
@ -258,6 +260,7 @@ impl Scratch {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
|
||||
let layer1_rn = conv2d_no_bias(
|
||||
@ -319,6 +322,7 @@ impl Scratch {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let output_conv1 = conv2d(
|
||||
conf.num_features,
|
||||
@ -425,6 +429,7 @@ impl DPTHead {
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
},
|
||||
vb.pp("resize_layers").pp("3"),
|
||||
)?),
|
||||
|
@ -468,6 +468,7 @@ impl EncodecConv1d {
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
|
@ -267,6 +267,7 @@ impl StreamableConv1d {
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
|
||||
if k_size < stride {
|
||||
|
@ -68,6 +68,7 @@ impl ResnetBlock2D {
|
||||
padding: 1,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
|
||||
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
|
||||
@ -83,6 +84,7 @@ impl ResnetBlock2D {
|
||||
padding: 0,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
Some(conv2d(
|
||||
in_channels,
|
||||
|
@ -248,12 +248,14 @@ impl AudioEncoder {
|
||||
stride: 1,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let cfg2 = Conv1dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
|
@ -244,12 +244,14 @@ impl AudioEncoder {
|
||||
stride: 1,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let cfg2 = Conv1dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
|
@ -98,6 +98,7 @@ impl ConvBlock {
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
|
||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||
|
Reference in New Issue
Block a user