diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 07c5dfa8..4ec41b87 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -11,10 +11,6 @@ use cudarc::driver::{ use half::{bf16, f16}; use std::sync::{Arc, Mutex}; -const USE_IM2COL_CONV1D: bool = true; -#[cfg(not(feature = "cudnn"))] -const USE_IM2COL_CONV2D: bool = true; - /// cudarc related errors #[derive(thiserror::Error, Debug)] pub enum CudaError { @@ -1760,6 +1756,8 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv1D, ) -> Result { + const USE_IM2COL_CONV1D: bool = true; + let device = self.device().clone(); if !USE_IM2COL_CONV1D { let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; @@ -1808,6 +1806,8 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result { + const USE_IM2COL_CONV2D: bool = true; + let device = self.device().clone(); if !USE_IM2COL_CONV2D { let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;