Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)

* Skeleton for the avg-pool2d and upsample-nearest2d ops.

* Preliminary conv2d support.
This commit is contained in:
Laurent Mazare
2023-08-07 17:15:38 +02:00
committed by GitHub
parent f53a333ea9
commit 2345b8ce3f
7 changed files with 88 additions and 17 deletions

View File

@ -55,6 +55,11 @@ impl Tensor {
kernel: rhs,
..
}
| Op::Conv2D {
arg: lhs,
kernel: rhs,
..
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
@ -81,6 +86,8 @@ impl Tensor {
}
}
Op::Reshape(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@ -163,6 +170,11 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;