From dc47224ab9d34c8f4ea0e6ce87d964a030eae89c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 25 Sep 2023 10:31:53 +0100 Subject: [PATCH] Override the default cudnn heuristics. (#957) --- candle-core/src/conv.rs | 16 ++++++++++++++++ candle-core/src/cudnn.rs | 18 +++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 1f3ef582..f92c05b2 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -25,6 +25,20 @@ impl ParamsConv1D { } } +#[allow(unused)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum CudnnFwdAlgo { + ImplicitGemm, + ImplicitPrecompGemm, + Gemm, + Direct, + Fft, + FftTiling, + Winograd, + WinogradNonFused, + Count, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv2D { pub(crate) b_size: usize, @@ -37,6 +51,7 @@ pub struct ParamsConv2D { pub(crate) padding: usize, pub(crate) stride: usize, pub(crate) dilation: usize, + pub(crate) cudnn_fwd_algo: Option, } impl ParamsConv2D { @@ -188,6 +203,7 @@ impl Tensor { padding, stride, dilation, + cudnn_fwd_algo: None, }; if groups == 1 { self.conv2d_single_group(kernel, ¶ms) diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs index dd466ba2..0c149cd0 100644 --- a/candle-core/src/cudnn.rs +++ b/candle-core/src/cudnn.rs @@ -34,6 +34,9 @@ pub(crate) fn launch_conv2d< params: &crate::conv::ParamsConv2D, dev: &crate::cuda_backend::CudaDevice, ) -> crate::Result<()> { + use crate::conv::CudnnFwdAlgo as CandleAlgo; + use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A; + let device_id = dev.id(); let cudnn = CUDNN.with(|cudnn| { if let Some(cudnn) = cudnn.borrow().get(&device_id) { @@ -90,7 +93,20 @@ pub(crate) fn launch_conv2d< w: &w, y: &y, }; - let alg = conv2d.pick_algorithm()?; + let alg = match params.cudnn_fwd_algo { + None => conv2d.pick_algorithm()?, + Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + Some(CandleAlgo::ImplicitPrecompGemm) => { + A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM + } + Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT, + Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT, + }; let workspace_size = conv2d.get_workspace_size(alg)?; let mut workspace = dev.cuda_device().alloc_zeros::(workspace_size)?; unsafe {