/// Computes the output length of a convolution along a single spatial axis.
///
/// Implements `(input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1`, where
/// `padding` is added to both ends of the axis. The expression is evaluated as
/// `(padded - span) / stride + 1` so that no intermediate value can underflow.
///
/// # Panics
///
/// Panics if `kernel` or `stride` is zero, or if the dilated kernel span
/// (`dilation * (kernel - 1) + 1`) is longer than the padded input. Without these checks the
/// unsigned arithmetic panics with an opaque overflow in debug builds and silently wraps to a
/// huge length in release builds.
fn conv_output_len(
    input: usize,
    padding: usize,
    kernel: usize,
    stride: usize,
    dilation: usize,
) -> usize {
    assert!(kernel > 0, "convolution kernel size must be non-zero");
    assert!(stride > 0, "convolution stride must be non-zero");
    let padded = input + 2 * padding;
    // Number of input positions covered by one application of the (dilated) kernel.
    let span = dilation * (kernel - 1) + 1;
    assert!(
        span <= padded,
        "dilated kernel span {} exceeds padded input length {}",
        span,
        padded
    );
    (padded - span) / stride + 1
}

/// Shape parameters of a 1D convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    /// Batch dimension; reported as the first entry of `out_dims`.
    pub(crate) b_size: usize,
    /// Input length along the convolved axis.
    // Maybe we should have a version without `l_in`, as this value depends on the input and not
    // only on the weights.
    pub(crate) l_in: usize,
    /// Number of output channels; reported as the second entry of `out_dims`.
    pub(crate) c_out: usize,
    /// Number of input channels (not used by the shape computations below).
    pub(crate) c_in: usize,
    /// Kernel length.
    pub(crate) k_size: usize,
    /// Padding added to each end of the input axis.
    pub(crate) padding: usize,
    /// Step between successive kernel applications.
    pub(crate) stride: usize,
}

impl ParamsConv1D {
    /// Output length along the convolved axis (dilation fixed at 1).
    ///
    /// # Panics
    ///
    /// Panics if `k_size` or `stride` is zero, or if the kernel is longer than the padded input.
    pub(crate) fn l_out(&self) -> usize {
        let dilation = 1;
        conv_output_len(self.l_in, self.padding, self.k_size, self.stride, dilation)
    }

    /// Output shape as `[b_size, c_out, l_out]`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`ParamsConv1D::l_out`].
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.l_out()]
    }
}

/// Shape parameters of a 2D convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    /// Batch dimension; reported as the first entry of `out_dims`.
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Number of output channels; reported as the second entry of `out_dims`.
    pub(crate) c_out: usize,
    /// Number of input channels (not used by the shape computations below).
    pub(crate) c_in: usize,
    /// Padding added to each end of both the height and width axes.
    pub(crate) padding: usize,
    /// Step between successive kernel applications, shared by both axes.
    pub(crate) stride: usize,
}

impl ParamsConv2D {
    /// Output height (dilation fixed at 1).
    ///
    /// # Panics
    ///
    /// Panics if `k_h` or `stride` is zero, or if the kernel is taller than the padded input.
    pub(crate) fn out_h(&self) -> usize {
        let dilation = 1;
        conv_output_len(self.i_h, self.padding, self.k_h, self.stride, dilation)
    }

    /// Output width (dilation fixed at 1).
    ///
    /// # Panics
    ///
    /// Panics if `k_w` or `stride` is zero, or if the kernel is wider than the padded input.
    pub(crate) fn out_w(&self) -> usize {
        let dilation = 1;
        conv_output_len(self.i_w, self.padding, self.k_w, self.stride, dilation)
    }

    /// Output shape as `[b_size, c_out, out_h, out_w]`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`ParamsConv2D::out_h`] and [`ParamsConv2D::out_w`].
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}