Backprop for conv2d. (#638)

* Start adding backprop for conv2d.

* Backprop for conv2d.

* Bugfix + start adding a conv2d test.

* Conv2d backprop testing.

* More conv fixes.
This commit is contained in:
Laurent Mazare
2023-08-28 16:08:55 +01:00
committed by GitHub
parent 9137c63175
commit b292047882
4 changed files with 106 additions and 13 deletions

View File

@ -1204,7 +1204,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
let inp_x = out_x * p.stride as i32 - p.padding as i32;
let inp_y = out_y * p.stride as i32 - p.padding as i32;
for k_y in 0..p.k_h as i32 {
for k_x in 0..p.k_h as i32 {
for k_x in 0..p.k_w as i32 {
let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
let inp_y = inp_y + k_y;
let inp_x = inp_x + k_x;
@ -1215,9 +1215,11 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
let inp_y = inp_y as usize;
if inp_x < p.i_w && inp_y < p.i_h {
let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
let dst_index = b_idx * dst_s0 + inp_y * dst_s2 + inp_x * dst_s3;
for c_out in 0..k_s0 {
for c_in in 0..k_s1 {
let dst_index = b_idx * dst_s0
+ out_y as usize * dst_s2
+ out_x as usize * dst_s3;
for c_out in 0..p.c_out {
for c_in in 0..p.c_in {
let k_index = k_index + c_out * k_s1 + c_in * k_s0;
let dst_index = dst_index + c_out * dst_s1;
let inp_index = inp_index + c_in * inp_s1;