mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Adding upsample_nearest_2d.
This commit is contained in:
@ -959,8 +959,39 @@ impl BackendStorage for MetalStorage {
|
||||
crate::bail!("upsample_nearest1d metal")
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||
crate::bail!("upsample_nearest2d metal")
|
||||
fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
|
||||
// let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let strides = inp_l.stride();
|
||||
if dims.len() != 4 {
|
||||
crate::bail!("unexpected input shape for upsample {dims:?}")
|
||||
}
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "upsample_nearest2d_f32",
|
||||
dtype => crate::bail!("Not implemented {dtype:?} for upsample_nearest2d, metal"),
|
||||
};
|
||||
|
||||
let dst_el = out_w * out_h * dims[0] * dims[1];
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_upsample_nearest_2d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
dims,
|
||||
strides,
|
||||
out_w,
|
||||
out_h,
|
||||
&self.buffer,
|
||||
inp_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype))
|
||||
}
|
||||
|
||||
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
|
Reference in New Issue
Block a user