// Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00).
// Viewer metadata: 259 lines, 7.8 KiB, Rust.
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
|
|
|
|
/// Shape and hyper-parameters of a single 1D convolution call.
///
/// `Tensor::conv1d` builds this from the input and kernel dimensions and hands it to the
/// storage backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    /// Batch size (leading input dimension).
    pub(crate) b_size: usize,
    // Maybe we should have a version without l_in as this bit depends on the input and not only on
    // the weights.
    /// Length of the input along the convolved dimension.
    pub(crate) l_in: usize,
    /// Output channels produced by one call (`Tensor::conv1d` stores `c_out / groups` here).
    pub(crate) c_out: usize,
    /// Input channels consumed by one call (`Tensor::conv1d` stores `c_in / groups` here).
    pub(crate) c_in: usize,
    /// Number of taps in the kernel.
    pub(crate) k_size: usize,
    /// Padding counted on each side of the input (the size formula uses `2 * padding`).
    pub(crate) padding: usize,
    /// Step between two consecutive kernel applications; must be non-zero.
    pub(crate) stride: usize,
    /// Spacing between kernel taps.
    pub(crate) dilation: usize,
}

impl ParamsConv1D {
    /// Output length: `(l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1`.
    ///
    /// Returns 0 when the dilated kernel does not fit in the padded input (or when the
    /// kernel is empty) instead of underflowing the unsigned arithmetic.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `stride` is 0.
    pub(crate) fn l_out(&self) -> usize {
        let padded = self.l_in + 2 * self.padding;
        // Extent covered by the kernel once the dilation gaps are accounted for.
        let span = match self.k_size.checked_sub(1) {
            Some(taps) => self.dilation * taps + 1,
            None => return 0,
        };
        match padded.checked_sub(span) {
            Some(slack) => slack / self.stride + 1,
            None => 0,
        }
    }

    /// Output dimensions as `[b_size, c_out, l_out]`.
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.l_out()]
    }
}
|
|
|
|
/// Shape and hyper-parameters of a single 2D convolution call.
///
/// `Tensor::conv2d` builds this from the input and kernel dimensions and hands it to the
/// storage backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    /// Batch size (leading input dimension).
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Output channels produced by one call (`Tensor::conv2d` stores `c_out / groups` here).
    pub(crate) c_out: usize,
    /// Input channels consumed by one call (`Tensor::conv2d` stores `c_in / groups` here).
    pub(crate) c_in: usize,
    /// Padding counted on each side of both spatial dimensions.
    pub(crate) padding: usize,
    /// Step between two consecutive kernel applications; must be non-zero.
    pub(crate) stride: usize,
    /// Spacing between kernel taps.
    pub(crate) dilation: usize,
}

impl ParamsConv2D {
    /// Output size along one spatial axis:
    /// `(input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1`.
    ///
    /// Returns 0 when the dilated kernel does not fit in the padded input (or when the
    /// kernel is empty) instead of underflowing the unsigned arithmetic.
    fn out_dim(&self, input: usize, kernel: usize) -> usize {
        let padded = input + 2 * self.padding;
        // Extent covered by the kernel once the dilation gaps are accounted for.
        let span = match kernel.checked_sub(1) {
            Some(taps) => self.dilation * taps + 1,
            None => return 0,
        };
        match padded.checked_sub(span) {
            Some(slack) => slack / self.stride + 1,
            None => 0,
        }
    }

    /// Output height.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `stride` is 0.
    pub(crate) fn out_h(&self) -> usize {
        self.out_dim(self.i_h, self.k_h)
    }

    /// Output width.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `stride` is 0.
    pub(crate) fn out_w(&self) -> usize {
        self.out_dim(self.i_w, self.k_w)
    }

    /// Output dimensions as `[b_size, c_out, out_h, out_w]`.
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}
|
|
|
|
/// Shape and hyper-parameters of a 2D transposed convolution call.
///
/// `Tensor::conv_transpose2d` builds this from the input and kernel dimensions and hands it
/// to the storage backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    /// Batch size (leading input dimension).
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Output channels per group (second kernel dimension in `Tensor::conv_transpose2d`).
    pub(crate) c_out_per_group: usize,
    /// Total number of input channels.
    pub(crate) c_in: usize,
    /// Padding; `2 * padding` is subtracted from each output dimension.
    pub(crate) padding: usize,
    /// Extra size added to each output dimension.
    pub(crate) output_padding: usize,
    /// Multiplier applied to `input - 1` in the output size formula.
    pub(crate) stride: usize,
    /// Spacing between kernel taps.
    pub(crate) dilation: usize,
    /// Number of groups; the output has `c_out_per_group * groups` channels.
    pub(crate) groups: usize,
}

impl ParamsConvTranspose2D {
    /// Output size along one spatial axis:
    /// `(input - 1) * stride + dilation * (kernel - 1) + output_padding + 1 - 2 * padding`.
    ///
    /// Returns 0 when the input or kernel is empty, or when the padding removes more than
    /// the unpadded extent, instead of underflowing the unsigned arithmetic.
    fn out_dim(&self, input: usize, kernel: usize) -> usize {
        match (input.checked_sub(1), kernel.checked_sub(1)) {
            (Some(steps), Some(taps)) => {
                let unpadded =
                    steps * self.stride + self.dilation * taps + self.output_padding + 1;
                unpadded.saturating_sub(2 * self.padding)
            }
            _ => 0,
        }
    }

    /// Output height.
    pub(crate) fn out_h(&self) -> usize {
        self.out_dim(self.i_h, self.k_h)
    }

    /// Output width.
    pub(crate) fn out_w(&self) -> usize {
        self.out_dim(self.i_w, self.k_w)
    }

    /// Output dimensions as `[b_size, c_out_per_group * groups, out_h, out_w]`.
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        let c_out = self.c_out_per_group * self.groups;
        vec![self.b_size, c_out, self.out_h(), self.out_w()]
    }
}
|
|
|
|
impl Tensor {
|
|
fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
|
|
let storage =
|
|
self.storage()
|
|
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
|
|
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
|
|
arg,
|
|
kernel,
|
|
padding: params.padding,
|
|
stride: params.stride,
|
|
dilation: params.dilation,
|
|
});
|
|
let out_dims = params.out_dims();
|
|
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
|
}
|
|
|
|
/// Applies a 1D convolution over the input tensor.
|
|
pub fn conv1d(
|
|
&self,
|
|
kernel: &Self,
|
|
padding: usize,
|
|
stride: usize,
|
|
dilation: usize,
|
|
groups: usize,
|
|
) -> Result<Self> {
|
|
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
|
let (b_size, c_in, l_in) = self.dims3()?;
|
|
if c_in != c_in_k * groups {
|
|
Err(Error::Conv1dInvalidArgs {
|
|
inp_shape: self.shape().clone(),
|
|
k_shape: kernel.shape().clone(),
|
|
padding,
|
|
stride,
|
|
msg: "the number of in-channels on the input doesn't match the kernel size",
|
|
}
|
|
.bt())?
|
|
}
|
|
|
|
let params = ParamsConv1D {
|
|
b_size,
|
|
l_in,
|
|
c_out: c_out / groups,
|
|
c_in: c_in / groups,
|
|
k_size,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
};
|
|
if groups == 1 {
|
|
self.conv1d_single_group(kernel, ¶ms)
|
|
} else {
|
|
let blocks = self.chunk(groups, 1)?;
|
|
let kernel = kernel.chunk(groups, 0)?;
|
|
let blocks = blocks
|
|
.iter()
|
|
.zip(&kernel)
|
|
.map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms))
|
|
.collect::<Result<Vec<_>>>()?;
|
|
Tensor::cat(&blocks, 1)
|
|
}
|
|
}
|
|
|
|
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
|
let storage =
|
|
self.storage()
|
|
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
|
|
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
|
|
arg,
|
|
kernel,
|
|
padding: params.padding,
|
|
stride: params.stride,
|
|
dilation: params.dilation,
|
|
});
|
|
let out_dims = params.out_dims();
|
|
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
|
}
|
|
|
|
/// Applies a 2D convolution over the input tensor.
|
|
pub fn conv2d(
|
|
&self,
|
|
kernel: &Self,
|
|
padding: usize,
|
|
stride: usize,
|
|
dilation: usize,
|
|
groups: usize,
|
|
) -> Result<Self> {
|
|
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
|
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
|
if c_in != c_in_k * groups {
|
|
crate::bail!(
|
|
"in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
|
|
)
|
|
}
|
|
let params = ParamsConv2D {
|
|
b_size,
|
|
i_h,
|
|
i_w,
|
|
k_h,
|
|
k_w,
|
|
c_out: c_out / groups,
|
|
c_in: c_in / groups,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
};
|
|
if groups == 1 {
|
|
self.conv2d_single_group(kernel, ¶ms)
|
|
} else {
|
|
let blocks = self.chunk(groups, 1)?;
|
|
let kernel = kernel.chunk(groups, 0)?;
|
|
let blocks = blocks
|
|
.iter()
|
|
.zip(&kernel)
|
|
.map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms))
|
|
.collect::<Result<Vec<_>>>()?;
|
|
Tensor::cat(&blocks, 1)
|
|
}
|
|
}
|
|
|
|
/// Applies a 2D transposed convolution over the input tensor.
|
|
pub fn conv_transpose2d(
|
|
&self,
|
|
kernel: &Self,
|
|
padding: usize,
|
|
output_padding: usize,
|
|
stride: usize,
|
|
dilation: usize,
|
|
groups: usize,
|
|
) -> Result<Self> {
|
|
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
|
let (c_in_k, c_out_per_group, k_h, k_w) = kernel.dims4()?;
|
|
if c_in != c_in_k {
|
|
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
|
}
|
|
if c_in % groups != 0 {
|
|
crate::bail!("in_channel {c_in} must be divisible by groups {groups}")
|
|
}
|
|
let params = ParamsConvTranspose2D {
|
|
b_size,
|
|
i_h,
|
|
i_w,
|
|
k_h,
|
|
k_w,
|
|
c_out_per_group,
|
|
c_in,
|
|
padding,
|
|
output_padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
};
|
|
let storage = self.storage().conv_transpose2d(
|
|
self.layout(),
|
|
&kernel.storage(),
|
|
kernel.layout(),
|
|
¶ms,
|
|
)?;
|
|
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
|
|
arg,
|
|
kernel,
|
|
padding: params.padding,
|
|
output_padding: params.output_padding,
|
|
stride: params.stride,
|
|
dilation: params.dilation,
|
|
groups: params.groups,
|
|
});
|
|
let out_dims = params.out_dims();
|
|
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
|
}
|
|
}
|