Add the conv-transpose1d op. (#1251)
* Skeleton structure for conv-transpose1d.

* CPU implementation for conv-transpose1d.
candle-core/src/backend.rs
@@ -39,6 +39,14 @@ pub trait BackendStorage: Sized {
         _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self>;
 
+    fn conv_transpose1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self>;
+
     fn conv2d(
         &self,
         _l: &Layout,
candle-core/src/backprop.rs
@@ -57,6 +57,11 @@ impl Tensor {
                     kernel: rhs,
                     ..
                 }
+                | Op::ConvTranspose1D {
+                    arg: lhs,
+                    kernel: rhs,
+                    ..
+                }
                 | Op::Conv2D {
                     arg: lhs,
                     kernel: rhs,
@@ -247,6 +252,9 @@ impl Tensor {
                     };
                     *sum_grad = sum_grad.add(&grad_kernel)?;
                 }
+                Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
+                    op: "conv-transpose1d",
+                })?,
                 Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                     op: "conv-transpose2d",
                 })?,
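Only the graph traversal above is wired up for the new op; its backward pass intentionally errors out. Below is a minimal sketch of what a caller would observe. Assumptions not part of this diff: the `candle_core` crate name as used elsewhere in the repo, and a `Var` input so that the op is tracked for backprop.

use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A Var makes the graph track the op for backprop.
    let x = Var::randn(0f32, 1f32, (1, 4, 5), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (4, 4, 3), &dev)?;
    let y = x.conv_transpose1d(&k, 0, 0, 1, 1)?;
    // Expected to fail with Error::BackwardNotSupported { op: "conv-transpose1d" }.
    assert!(y.sum_all()?.backward().is_err());
    Ok(())
}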
candle-core/src/conv.rs
@@ -25,6 +25,33 @@ impl ParamsConv1D {
     }
 }
 
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConvTranspose1D {
+    pub(crate) b_size: usize,
+    pub(crate) l_in: usize,
+    pub(crate) c_out: usize,
+    pub(crate) c_in: usize,
+    pub(crate) k_size: usize,
+    pub(crate) padding: usize,
+    pub(crate) output_padding: usize,
+    pub(crate) stride: usize,
+    pub(crate) dilation: usize,
+}
+
+impl ParamsConvTranspose1D {
+    pub(crate) fn l_out(&self) -> usize {
+        (self.l_in - 1) * self.stride - 2 * self.padding
+            + self.dilation * (self.k_size - 1)
+            + self.output_padding
+            + 1
+    }
+
+    pub(crate) fn out_dims(&self) -> Vec<usize> {
+        let l_out = self.l_out();
+        vec![self.b_size, self.c_out, l_out]
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum CudnnFwdAlgo {
     ImplicitGemm,
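The `l_out` formula is the standard transposed-convolution output-length rule, i.e. the inverse of the conv1d length computation. A self-contained sketch with made-up example values (the free function mirrors `ParamsConvTranspose1D::l_out`):

// Standalone mirror of the l_out formula above, for checking by hand.
fn l_out(
    l_in: usize, stride: usize, padding: usize,
    dilation: usize, k_size: usize, output_padding: usize,
) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + output_padding + 1
}

fn main() {
    // l_in = 5, stride = 2, padding = 1, dilation = 1, k_size = 3, output_padding = 0:
    // (5 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 9.
    assert_eq!(l_out(5, 2, 1, 1, 3, 0), 9);
}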
@@ -160,6 +187,49 @@ impl Tensor {
         }
     }
 
+    /// Applies a 1D transposed convolution over the input tensor.
+    pub fn conv_transpose1d(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        output_padding: usize,
+        stride: usize,
+        dilation: usize,
+    ) -> Result<Self> {
+        let (c_out, c_in_k, k_size) = kernel.dims3()?;
+        let (b_size, c_in, l_in) = self.dims3()?;
+        if c_in != c_in_k {
+            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+        }
+        let params = ParamsConvTranspose1D {
+            b_size,
+            l_in,
+            k_size,
+            c_out,
+            c_in,
+            padding,
+            output_padding,
+            stride,
+            dilation,
+        };
+        let storage = self.storage().conv_transpose1d(
+            self.layout(),
+            &kernel.storage(),
+            kernel.layout(),
+            &params,
+        )?;
+        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
+            arg,
+            kernel,
+            padding: params.padding,
+            output_padding: params.output_padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        });
+        let out_dims = params.out_dims();
+        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+    }
+
     fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
         let storage =
             self.storage()
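A short usage sketch of the new method. Assumptions beyond this diff: the `candle_core` crate name and `Tensor::randn` for test data; following the `dims3` destructuring above, the kernel is shaped `(c_out, c_in, k_size)`, and equal channel counts are used here to keep the example simple.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (1, 4, 5), &dev)?; // (b_size, c_in, l_in)
    let k = Tensor::randn(0f32, 1f32, (4, 4, 3), &dev)?; // (c_out, c_in, k_size)
    // padding = 0, output_padding = 0, stride = 1, dilation = 1.
    let y = x.conv_transpose1d(&k, 0, 0, 1, 1)?;
    // l_out = (5 - 1) * 1 - 0 + 1 * (3 - 1) + 0 + 1 = 7.
    assert_eq!(y.dims(), &[1, 4, 7]);
    Ok(())
}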
candle-core/src/cpu_backend.rs
@@ -1256,6 +1256,74 @@ impl Map1 for Im2Col {
     }
 }
 
+struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
+
+impl<'a> Map2 for ConvTranspose1D<'a> {
+    const OP: &'static str = "conv_transpose1d";
+    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
+        let p = self.0;
+        let inp = &inp[inp_l.start_offset()..];
+        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
+        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
+        let l_out = p.l_out();
+
+        // Output shape: [b_size, c_out, l_out].
+        let dst_elems = p.c_out * l_out * p.b_size;
+        let dst = vec![T::zero(); dst_elems];
+        let dst_s0 = p.c_out * l_out;
+        let dst_s1 = l_out;
+        let dst_s2 = 1;
+
+        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
+        let cont_s0 = p.l_in * p.c_in;
+        let cont_s1 = p.c_in;
+        for b_idx in 0..p.b_size {
+            for l_idx in 0..p.l_in {
+                for c_idx in 0..p.c_in {
+                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
+                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
+                    inp_cont[dst_idx] = inp[src_idx]
+                }
+            }
+        }
+
+        for k_idx in 0..p.k_size {
+            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
+                let k_cont = (0..p.c_in)
+                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
+                    .collect::<Vec<_>>();
+                for b_idx in 0..p.b_size {
+                    for l_idx in 0..p.l_in {
+                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
+                        if out_idx < p.padding {
+                            continue;
+                        }
+                        let out_idx = out_idx - p.padding;
+                        if out_idx < l_out {
+                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
+                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
+                            let mut d = T::zero();
+                            unsafe {
+                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
+                            }
+                            let dst_p = dst.as_ptr();
+                            // Safety: each dst_idx is unique per dst_c_idx, which is used to
+                            // parallelise the tasks, so no two threads can try to write to
+                            // the same location.
+                            unsafe {
+                                let ptr = dst_p.add(dst_idx) as *mut T;
+                                *ptr += d
+                            }
+                        }
+                    }
+                }
+            })
+        }
+        Ok(dst)
+    }
+}
+
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
 
 impl<'a> Map2 for Conv2D<'a> {
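The kernel scatters rather than gathers: each input position `l_idx` and kernel tap `k_idx` contribute to output position `l_idx * stride + k_idx * dilation - padding`, with a `c_in`-long dot product per contribution (vectorised via `T::vec_dot` and parallelised over `c_out`). A single-batch, single-channel reference sketch of the same index math, with illustrative names:

// Naive scalar reference: out[l * stride + k * dilation - padding] += inp[l] * w[k].
fn conv_transpose1d_ref(
    inp: &[f32], w: &[f32],
    stride: usize, padding: usize, dilation: usize, output_padding: usize,
) -> Vec<f32> {
    let l_out =
        (inp.len() - 1) * stride + dilation * (w.len() - 1) + output_padding + 1 - 2 * padding;
    let mut out = vec![0f32; l_out];
    for (l_idx, &x) in inp.iter().enumerate() {
        for (k_idx, &wk) in w.iter().enumerate() {
            let pos = l_idx * stride + k_idx * dilation;
            // Contributions that land in the cropped padding region are skipped.
            if pos >= padding && pos - padding < l_out {
                out[pos - padding] += x * wk;
            }
        }
    }
    out
}

fn main() {
    // stride 1, no padding or dilation: a plain full correlation of length 4.
    let out = conv_transpose1d_ref(&[1., 2.], &[1., 1., 1.], 1, 0, 1, 0);
    assert_eq!(out, vec![1., 3., 3., 2.]);
}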
@@ -2435,6 +2503,16 @@ impl BackendStorage for CpuStorage {
         Ok(res_t)
     }
 
+    fn conv_transpose1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
+    }
+
     fn conv2d(
         &self,
         l: &Layout,
candle-core/src/cuda_backend.rs
@@ -1808,6 +1808,16 @@ impl BackendStorage for CudaStorage {
         Ok(res_t)
     }
 
+    fn conv_transpose1d(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     #[cfg(not(feature = "cudnn"))]
     fn conv2d(
         &self,
candle-core/src/dummy_cuda_backend.rs
@@ -79,6 +79,16 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    fn conv_transpose1d(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn conv2d(
         &self,
         _: &Layout,
candle-core/src/op.rs
@@ -90,6 +90,16 @@ pub enum Op {
         dilation: usize,
     },
 
+    #[allow(dead_code)]
+    ConvTranspose1D {
+        arg: Tensor,
+        kernel: Tensor,
+        padding: usize,
+        output_padding: usize,
+        stride: usize,
+        dilation: usize,
+    },
+
     #[allow(dead_code)]
     Conv2D {
         arg: Tensor,
candle-core/src/storage.rs
@@ -279,6 +279,33 @@ impl Storage {
         }
     }
 
+    pub(crate) fn conv_transpose1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        self.same_device(kernel, "conv-transpose1d")?;
+        self.same_dtype(kernel, "conv-transpose1d")?;
+        match (self, &kernel) {
+            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Cpu(s))
+            }
+            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Cuda(s))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "conv-transpose1d",
+            }
+            .bt()),
+        }
+    }
+
     pub(crate) fn conv2d(
         &self,
         l: &Layout,