mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add 1d upsampling. (#839)
* Add 1d upsampling. * Add the interpolate functions.
This commit is contained in:
@ -727,6 +727,36 @@ impl Map1 for MaxPool2D {
|
||||
}
|
||||
}
|
||||
|
||||
struct UpsampleNearest1D(usize);
|
||||
|
||||
impl Map1 for UpsampleNearest1D {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
// TODO: Specialized implementation for the case 2*sz?
|
||||
let dst_sz = self.0;
|
||||
let (b_sz, c, src_sz) = layout.shape().dims3()?;
|
||||
let stride = layout.stride();
|
||||
let stride_sz = stride[2];
|
||||
let src_index = layout.start_offset();
|
||||
let scale_sz = src_sz as f64 / dst_sz as f64;
|
||||
let mut dst = vec![T::zero(); b_sz * c * dst_sz];
|
||||
let src_idxs = (0..dst_sz)
|
||||
.map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..b_sz {
|
||||
let dst = &mut dst[b_idx * c * dst_sz..];
|
||||
let src_index = src_index + b_idx * stride[0];
|
||||
for c_idx in 0..c {
|
||||
let dst = &mut dst[c_idx * dst_sz..];
|
||||
let src_index = src_index + c_idx * stride[1];
|
||||
for (idx, src_idx) in src_idxs.iter().enumerate() {
|
||||
dst[idx] = src[src_index + src_idx * stride_sz]
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct UpsampleNearest2D(usize, usize);
|
||||
|
||||
impl Map1 for UpsampleNearest2D {
|
||||
@ -2137,6 +2167,10 @@ impl BackendStorage for CpuStorage {
|
||||
MaxPool2D(kernel_size, stride).map(self, layout)
|
||||
}
|
||||
|
||||
fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
|
||||
UpsampleNearest1D(sz).map(self, layout)
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
|
||||
UpsampleNearest2D(h, w).map(self, layout)
|
||||
}
|
||||
|
Reference in New Issue
Block a user