Files
candle/candle-core/src/conv.rs
Laurent Mazare b5bb5e056d Add more conv2d support. (#340)
* Add more conv2d support.

* Conv2d cpu work.

* Conv2d output shape.
2023-08-08 06:04:32 +01:00

57 lines
1.6 KiB
Rust

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: Option<usize>,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self) -> usize {
let dilation = 1;
(self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
match self.b_size {
None => vec![self.c_out, l_out],
Some(n) => vec![n, self.c_out, l_out],
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
pub(crate) b_size: usize,
pub(crate) i_h: usize,
pub(crate) i_w: usize,
pub(crate) k_h: usize,
pub(crate) k_w: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
let dilation = 1;
(self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
let dilation = 1;
(self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}