mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Fix the cpu kernel for conv-transpose. (#643)
This commit is contained in:
@ -1199,25 +1199,23 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
let dst_s2 = out_w;
|
||||
let dst_s3 = 1;
|
||||
for b_idx in 0..p.b_size {
|
||||
for out_y in 0..out_h as i32 {
|
||||
for out_x in 0..out_w as i32 {
|
||||
let inp_x = out_x * p.stride as i32 - p.padding as i32;
|
||||
let inp_y = out_y * p.stride as i32 - p.padding as i32;
|
||||
for inp_y in 0..p.i_h {
|
||||
for inp_x in 0..p.i_w {
|
||||
let out_x = (inp_x * p.stride) as i32 - p.padding as i32;
|
||||
let out_y = (inp_y * p.stride) as i32 - p.padding as i32;
|
||||
for k_y in 0..p.k_h as i32 {
|
||||
for k_x in 0..p.k_w as i32 {
|
||||
let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
|
||||
let inp_y = inp_y + k_y;
|
||||
let inp_x = inp_x + k_x;
|
||||
if inp_x < 0 || inp_y < 0 {
|
||||
let out_y = out_y + k_y;
|
||||
let out_x = out_x + k_x;
|
||||
if out_x < 0 || out_y < 0 {
|
||||
continue;
|
||||
}
|
||||
let inp_x = inp_x as usize;
|
||||
let inp_y = inp_y as usize;
|
||||
if inp_x < p.i_w && inp_y < p.i_h {
|
||||
let out_x = out_x as usize;
|
||||
let out_y = out_y as usize;
|
||||
if out_x < out_w && out_y < out_h {
|
||||
let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
|
||||
let dst_index = b_idx * dst_s0
|
||||
+ out_y as usize * dst_s2
|
||||
+ out_x as usize * dst_s3;
|
||||
let dst_index = b_idx * dst_s0 + out_y * dst_s2 + out_x * dst_s3;
|
||||
for c_out in 0..p.c_out {
|
||||
for c_in in 0..p.c_in {
|
||||
let k_index = k_index + c_out * k_s1 + c_in * k_s0;
|
||||
|
Reference in New Issue
Block a user