mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Support for cudnn conv1d. (#2888)
* Support for cudnn conv1d. * More conv1d work. * Get the conv1d to work with cudnn. * Cleanup.
This commit is contained in:
@ -14,6 +14,7 @@ pub struct ParamsConv1D {
|
|||||||
pub(crate) padding: usize,
|
pub(crate) padding: usize,
|
||||||
pub(crate) stride: usize,
|
pub(crate) stride: usize,
|
||||||
pub(crate) dilation: usize,
|
pub(crate) dilation: usize,
|
||||||
|
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ParamsConv1D {
|
impl ParamsConv1D {
|
||||||
@ -174,6 +175,7 @@ impl Tensor {
|
|||||||
padding,
|
padding,
|
||||||
stride,
|
stride,
|
||||||
dilation,
|
dilation,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
if groups == 1 {
|
if groups == 1 {
|
||||||
self.conv1d_single_group(kernel, ¶ms)
|
self.conv1d_single_group(kernel, ¶ms)
|
||||||
|
@ -122,3 +122,104 @@ pub(crate) fn launch_conv2d<
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn launch_conv1d<
|
||||||
|
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
|
||||||
|
Y: cudarc::cudnn::CudnnDataType,
|
||||||
|
>(
|
||||||
|
src: &CudaView<T>,
|
||||||
|
src_l: &crate::Layout,
|
||||||
|
filter: &CudaView<T>,
|
||||||
|
dst: &mut CudaSlice<T>,
|
||||||
|
params: &crate::conv::ParamsConv1D,
|
||||||
|
dev: &crate::cuda_backend::CudaDevice,
|
||||||
|
) -> crate::Result<()> {
|
||||||
|
use crate::conv::CudnnFwdAlgo as CandleAlgo;
|
||||||
|
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
|
||||||
|
|
||||||
|
let device_id = dev.id();
|
||||||
|
let cudnn = CUDNN.with(|cudnn| {
|
||||||
|
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||||
|
return Ok(cudnn.clone());
|
||||||
|
}
|
||||||
|
let c = Cudnn::new(dev.cuda_stream());
|
||||||
|
if let Ok(c) = &c {
|
||||||
|
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||||
|
}
|
||||||
|
c
|
||||||
|
})?;
|
||||||
|
let conv = cudnn.create_conv2d::<Y>(
|
||||||
|
/* pad */ [params.padding as i32, 0],
|
||||||
|
/* stride */ [params.stride as i32, 1],
|
||||||
|
/* dilation */ [params.dilation as i32, 1],
|
||||||
|
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||||
|
)?;
|
||||||
|
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
|
||||||
|
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
|
||||||
|
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
|
||||||
|
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
|
||||||
|
// > to 1.
|
||||||
|
let x_shape = [
|
||||||
|
params.b_size as i32,
|
||||||
|
params.c_in as i32,
|
||||||
|
params.l_in as i32,
|
||||||
|
1,
|
||||||
|
];
|
||||||
|
// Note that `src` already starts at the proper offset.
|
||||||
|
let x = if src_l.is_contiguous() {
|
||||||
|
cudnn.create_4d_tensor::<T>(
|
||||||
|
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||||
|
x_shape,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
let s = src_l.stride();
|
||||||
|
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
|
||||||
|
};
|
||||||
|
let w = cudnn.create_4d_filter::<T>(
|
||||||
|
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||||
|
[
|
||||||
|
params.c_out as i32,
|
||||||
|
params.c_in as i32,
|
||||||
|
params.k_size as i32,
|
||||||
|
1,
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
let l_out = params.l_out() as i32;
|
||||||
|
let y = cudnn.create_4d_tensor::<T>(
|
||||||
|
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||||
|
[params.b_size as i32, params.c_out as i32, l_out, 1],
|
||||||
|
)?;
|
||||||
|
let conv1d = ConvForward {
|
||||||
|
conv: &conv,
|
||||||
|
x: &x,
|
||||||
|
w: &w,
|
||||||
|
y: &y,
|
||||||
|
};
|
||||||
|
let alg = match params.cudnn_fwd_algo {
|
||||||
|
None => conv1d.pick_algorithm()?,
|
||||||
|
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||||
|
Some(CandleAlgo::ImplicitPrecompGemm) => {
|
||||||
|
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
|
||||||
|
}
|
||||||
|
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||||
|
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||||
|
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||||
|
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||||
|
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||||
|
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||||
|
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||||
|
};
|
||||||
|
let workspace_size = conv1d.get_workspace_size(alg)?;
|
||||||
|
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||||
|
unsafe {
|
||||||
|
conv1d.launch::<CudaSlice<u8>, _, _, _>(
|
||||||
|
alg,
|
||||||
|
Some(&mut workspace),
|
||||||
|
(T::one(), T::zero()),
|
||||||
|
src,
|
||||||
|
filter,
|
||||||
|
dst,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -134,6 +134,7 @@ impl Map1 for Elu {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
struct Im2Col1D {
|
struct Im2Col1D {
|
||||||
l_k: usize,
|
l_k: usize,
|
||||||
stride: usize,
|
stride: usize,
|
||||||
@ -142,6 +143,7 @@ struct Im2Col1D {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Im2Col1D {
|
impl Im2Col1D {
|
||||||
|
#[allow(unused)]
|
||||||
fn l_out(&self, l: usize) -> usize {
|
fn l_out(&self, l: usize) -> usize {
|
||||||
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
|
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
|
||||||
}
|
}
|
||||||
@ -1435,6 +1437,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cudnn"))]
|
||||||
fn conv1d(
|
fn conv1d(
|
||||||
&self,
|
&self,
|
||||||
l: &Layout,
|
l: &Layout,
|
||||||
@ -1485,6 +1488,72 @@ impl BackendStorage for CudaStorage {
|
|||||||
Ok(res_t)
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cudnn")]
|
||||||
|
fn conv1d(
|
||||||
|
&self,
|
||||||
|
inp_l: &Layout,
|
||||||
|
kernel: &Self,
|
||||||
|
kernel_l: &Layout,
|
||||||
|
params: &crate::conv::ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let device = self.device().clone();
|
||||||
|
if !kernel_l.is_contiguous() {
|
||||||
|
let slice = Conv1D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
|
||||||
|
return Ok(Self { slice, device });
|
||||||
|
}
|
||||||
|
let l_out = params.l_out();
|
||||||
|
let dst_el = params.c_out * l_out * params.b_size;
|
||||||
|
let slice = match (&self.slice, &kernel.slice) {
|
||||||
|
(S::U8(inp), S::U8(k)) => {
|
||||||
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
|
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
|
||||||
|
crate::cudnn::launch_conv1d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
|
||||||
|
.map_err(crate::Error::wrap)?;
|
||||||
|
S::U8(out)
|
||||||
|
}
|
||||||
|
(S::BF16(inp), S::BF16(k)) => {
|
||||||
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
|
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
|
||||||
|
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
|
||||||
|
// version.
|
||||||
|
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
|
||||||
|
crate::cudnn::launch_conv1d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
|
||||||
|
.map_err(crate::Error::wrap)?;
|
||||||
|
S::BF16(out)
|
||||||
|
}
|
||||||
|
(S::F16(inp), S::F16(k)) => {
|
||||||
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
|
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
|
||||||
|
crate::cudnn::launch_conv1d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
|
||||||
|
.map_err(crate::Error::wrap)?;
|
||||||
|
S::F16(out)
|
||||||
|
}
|
||||||
|
(S::F32(inp), S::F32(k)) => {
|
||||||
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
|
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
|
||||||
|
crate::cudnn::launch_conv1d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
|
||||||
|
.map_err(crate::Error::wrap)?;
|
||||||
|
S::F32(out)
|
||||||
|
}
|
||||||
|
(S::F64(inp), S::F64(k)) => {
|
||||||
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
|
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
|
||||||
|
crate::cudnn::launch_conv1d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
|
||||||
|
.map_err(crate::Error::wrap)?;
|
||||||
|
S::F64(out)
|
||||||
|
}
|
||||||
|
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv1d does not support u32"))?,
|
||||||
|
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv1d does not support i64"))?,
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in conv1d"))?,
|
||||||
|
};
|
||||||
|
Ok(Self { slice, device })
|
||||||
|
}
|
||||||
|
|
||||||
fn conv_transpose1d(
|
fn conv_transpose1d(
|
||||||
&self,
|
&self,
|
||||||
l: &Layout,
|
l: &Layout,
|
||||||
|
Reference in New Issue
Block a user