mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)
* Skeleton for the avg-pool2d and upsample-nearest2d ops. * Preliminary conv2d support.
This commit is contained in:
@ -36,7 +36,7 @@ impl Downsample2D {
|
||||
impl Downsample2D {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match &self.conv {
|
||||
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
|
||||
None => xs.avg_pool2d((2, 2), (2, 2)),
|
||||
Some(conv) => {
|
||||
if self.padding == 0 {
|
||||
let xs = xs
|
||||
@ -72,13 +72,10 @@ impl Upsample2D {
|
||||
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
|
||||
let xs = match size {
|
||||
None => {
|
||||
// The following does not work and it's tricky to pass no fixed
|
||||
// dimensions so hack our way around this.
|
||||
// xs.upsample_nearest2d(&[], Some(2.), Some(2.)
|
||||
let (_bsize, _channels, _h, _w) = xs.dims4()?;
|
||||
crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.))
|
||||
let (_bsize, _channels, h, w) = xs.dims4()?;
|
||||
xs.upsample_nearest2d(2 * h, 2 * w)?
|
||||
}
|
||||
Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None),
|
||||
Some((h, w)) => xs.upsample_nearest2d(h, w)?,
|
||||
};
|
||||
self.conv.forward(&xs)
|
||||
}
|
||||
|
@ -1,13 +1,5 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
|
||||
if steps < 1 {
|
||||
candle::bail!("cannot use linspace with steps {steps} <= 1")
|
||||
|
Reference in New Issue
Block a user