mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Override the default cudnn heuristics. (#957)
This commit is contained in:
@ -25,6 +25,20 @@ impl ParamsConv1D {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum CudnnFwdAlgo {
|
||||||
|
ImplicitGemm,
|
||||||
|
ImplicitPrecompGemm,
|
||||||
|
Gemm,
|
||||||
|
Direct,
|
||||||
|
Fft,
|
||||||
|
FftTiling,
|
||||||
|
Winograd,
|
||||||
|
WinogradNonFused,
|
||||||
|
Count,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct ParamsConv2D {
|
pub struct ParamsConv2D {
|
||||||
pub(crate) b_size: usize,
|
pub(crate) b_size: usize,
|
||||||
@ -37,6 +51,7 @@ pub struct ParamsConv2D {
|
|||||||
pub(crate) padding: usize,
|
pub(crate) padding: usize,
|
||||||
pub(crate) stride: usize,
|
pub(crate) stride: usize,
|
||||||
pub(crate) dilation: usize,
|
pub(crate) dilation: usize,
|
||||||
|
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ParamsConv2D {
|
impl ParamsConv2D {
|
||||||
@ -188,6 +203,7 @@ impl Tensor {
|
|||||||
padding,
|
padding,
|
||||||
stride,
|
stride,
|
||||||
dilation,
|
dilation,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
if groups == 1 {
|
if groups == 1 {
|
||||||
self.conv2d_single_group(kernel, &params)
|
self.conv2d_single_group(kernel, &params)
|
||||||
|
@ -34,6 +34,9 @@ pub(crate) fn launch_conv2d<
|
|||||||
params: &crate::conv::ParamsConv2D,
|
params: &crate::conv::ParamsConv2D,
|
||||||
dev: &crate::cuda_backend::CudaDevice,
|
dev: &crate::cuda_backend::CudaDevice,
|
||||||
) -> crate::Result<()> {
|
) -> crate::Result<()> {
|
||||||
|
use crate::conv::CudnnFwdAlgo as CandleAlgo;
|
||||||
|
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
|
||||||
|
|
||||||
let device_id = dev.id();
|
let device_id = dev.id();
|
||||||
let cudnn = CUDNN.with(|cudnn| {
|
let cudnn = CUDNN.with(|cudnn| {
|
||||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||||
@ -90,7 +93,20 @@ pub(crate) fn launch_conv2d<
|
|||||||
w: &w,
|
w: &w,
|
||||||
y: &y,
|
y: &y,
|
||||||
};
|
};
|
||||||
let alg = conv2d.pick_algorithm()?;
|
let alg = match params.cudnn_fwd_algo {
|
||||||
|
None => conv2d.pick_algorithm()?,
|
||||||
|
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||||
|
Some(CandleAlgo::ImplicitPrecompGemm) => {
|
||||||
|
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
|
||||||
|
}
|
||||||
|
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||||
|
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||||
|
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||||
|
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||||
|
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||||
|
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||||
|
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||||
|
};
|
||||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||||
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
||||||
unsafe {
|
unsafe {
|
||||||
|
Reference in New Issue
Block a user