mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add some conv2d tests. (#347)
* Add some conv2d tests. * Add a simpler conv2d test. * More conv2d testing + bugfix. * Add a todo.
This commit is contained in:
@ -1037,7 +1037,7 @@ struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
const OP: &'static str = "conv2d";
|
||||
fn f<T: 'static + num_traits::NumAssign + Copy>(
|
||||
fn f<T: 'static + num_traits::NumAssign + Copy + std::fmt::Display>(
|
||||
&self,
|
||||
inp: &[T],
|
||||
inp_l: &Layout,
|
||||
@ -1063,11 +1063,13 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
let dst_idx = dst_idx + dst_w;
|
||||
let mut d = T::zero();
|
||||
for offset_h in 0..p.k_h {
|
||||
let src_h_plus = p.stride * dst_h + offset_h;
|
||||
// TODO: Handle the case where padding is larger than p.k_h / 2.
|
||||
let src_h_plus = p.stride * dst_h + offset_h + p.k_h / 2 - p.padding;
|
||||
if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h {
|
||||
let src_h = src_h_plus - p.k_h / 2;
|
||||
for offset_w in 0..p.k_w {
|
||||
let src_w_plus = p.stride * dst_w + offset_w;
|
||||
let src_w_plus =
|
||||
p.stride * dst_w + offset_w + p.k_w / 2 - p.padding;
|
||||
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
|
||||
if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w {
|
||||
let src_w = src_w_plus - p.k_w / 2;
|
||||
|
Reference in New Issue
Block a user