mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Backprop for conv2d. (#638)
* Start adding backprop for conv2d. * Backprop for conv2d. * Bugfix + start adding a conv2d test. * Conv2d backprop testing. * More conv fixes.
This commit is contained in:
@ -192,7 +192,30 @@ impl Tensor {
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
// Backward-pass dispatch for the convolution ops.
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D {
    arg,
    kernel,
    padding,
    stride,
} => {
    // Gradient w.r.t. the input: a transposed convolution of the output
    // gradient with the same kernel. The output height for conv_transpose2d is
    //   (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
    // With stride > 1 several input sizes map to the same output size, so
    // out_padding is computed explicitly to recover the exact input size.
    let grad_h = grad.dim(2)?;
    let k_h = kernel.dim(2)?;
    let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
    let out_padding = arg.dim(2)? - out_size;
    let grad_arg = grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&grad_arg)?;

    // Gradient w.r.t. the kernel: convolve the channel-transposed input with
    // the channel-transposed output gradient, then transpose back so the
    // result has the kernel's (c_out, c_in, k_h, k_w) layout.
    let grad_kernel = arg
        .transpose(0, 1)?
        .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
        .transpose(0, 1)?;
    let sum_grad = grads.or_insert(kernel)?;
    *sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
    op: "conv-transpose2d",
})?,
|
||||
|
@ -71,18 +71,14 @@ pub struct ParamsConvTranspose2D {
|
||||
impl ParamsConvTranspose2D {
|
||||
pub(crate) fn out_h(&self) -> usize {
|
||||
let dilation = 1;
|
||||
(self.i_h - 1) * self.stride - 2 * self.padding
|
||||
+ dilation * (self.k_h - 1)
|
||||
+ self.output_padding
|
||||
+ 1
|
||||
(self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
|
||||
- 2 * self.padding
|
||||
}
|
||||
|
||||
pub(crate) fn out_w(&self) -> usize {
|
||||
let dilation = 1;
|
||||
(self.i_w - 1) * self.stride - 2 * self.padding
|
||||
+ dilation * (self.k_w - 1)
|
||||
+ self.output_padding
|
||||
+ 1
|
||||
(self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
|
||||
- 2 * self.padding
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
|
@ -1204,7 +1204,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
let inp_x = out_x * p.stride as i32 - p.padding as i32;
|
||||
let inp_y = out_y * p.stride as i32 - p.padding as i32;
|
||||
for k_y in 0..p.k_h as i32 {
|
||||
for k_x in 0..p.k_h as i32 {
|
||||
for k_x in 0..p.k_w as i32 {
|
||||
let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
|
||||
let inp_y = inp_y + k_y;
|
||||
let inp_x = inp_x + k_x;
|
||||
@ -1215,9 +1215,11 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
let inp_y = inp_y as usize;
|
||||
if inp_x < p.i_w && inp_y < p.i_h {
|
||||
let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
|
||||
let dst_index = b_idx * dst_s0 + inp_y * dst_s2 + inp_x * dst_s3;
|
||||
for c_out in 0..k_s0 {
|
||||
for c_in in 0..k_s1 {
|
||||
let dst_index = b_idx * dst_s0
|
||||
+ out_y as usize * dst_s2
|
||||
+ out_x as usize * dst_s3;
|
||||
for c_out in 0..p.c_out {
|
||||
for c_in in 0..p.c_in {
|
||||
let k_index = k_index + c_out * k_s1 + c_in * k_s0;
|
||||
let dst_index = dst_index + c_out * dst_s1;
|
||||
let inp_index = inp_index + c_in * inp_s1;
|
||||
|
Reference in New Issue
Block a user