Adding upsample_nearest_2d.

This commit is contained in:
Nicolas Patry
2023-12-25 14:25:19 +01:00
parent 1505d85276
commit 13a5d15ebc
3 changed files with 137 additions and 2 deletions

View File

@ -959,8 +959,39 @@ impl BackendStorage for MetalStorage {
crate::bail!("upsample_nearest1d metal")
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
crate::bail!("upsample_nearest2d metal")
fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
// let inp = &inp.slice(inp_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let strides = inp_l.stride();
if dims.len() != 4 {
crate::bail!("unexpected input shape for upsample {dims:?}")
}
let name = match self.dtype {
DType::F32 => "upsample_nearest2d_f32",
dtype => crate::bail!("Not implemented {dtype:?} for upsample_nearest2d, metal"),
};
let dst_el = out_w * out_h * dims[0] * dims[1];
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_upsample_nearest_2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
dims,
strides,
out_w,
out_h,
&self.buffer,
inp_l.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), self.dtype))
}
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {