mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Adding upsample_nearest_2d.
This commit is contained in:
@ -108,6 +108,47 @@ METAL_FUNC void im2col1d(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void upsample_nearest2d(
|
||||
constant size_t &w_out,
|
||||
constant size_t &h_out,
|
||||
constant float &w_scale,
|
||||
constant float &h_scale,
|
||||
constant size_t *src_dims,
|
||||
constant size_t *src_s,
|
||||
device const T *src,
|
||||
device T *dst,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
// src: (b_size, c_in, w_in, h_in)
|
||||
|
||||
const size_t c = src_dims[1];
|
||||
const size_t w_in = src_dims[2];
|
||||
const size_t h_in = src_dims[3];
|
||||
|
||||
if (tid >= src_dims[0] * c * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: Improve this.
|
||||
const size_t b_idx = tid / (w_out * h_out * c);
|
||||
const size_t c_idx = (tid / (w_out * h_out)) % c;
|
||||
const size_t dst_w = (tid / h_out) % w_out;
|
||||
const size_t dst_h = tid % h_out;
|
||||
|
||||
size_t src_w = static_cast<size_t>(dst_w * w_scale);
|
||||
size_t src_h = static_cast<size_t>(dst_h * h_scale);
|
||||
if (src_w >= w_in) {
|
||||
src_w = w_in - 1;
|
||||
}
|
||||
if (src_h >= h_in) {
|
||||
src_h = h_in - 1;
|
||||
}
|
||||
|
||||
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
|
||||
dst[tid] = src[src_i];
|
||||
}
|
||||
|
||||
#define IM2COL_OP(T, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dst_numel, \
|
||||
@ -143,6 +184,21 @@ kernel void FN_NAME( \
|
||||
) { \
|
||||
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &w_out, \
|
||||
constant size_t &h_out, \
|
||||
constant float &w_scale, \
|
||||
constant float &h_scale, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const TYPENAME *src, \
|
||||
device TYPENAME *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
IM2COL_OP(float, im2col_f32)
|
||||
IM2COL_OP(uint8_t, im2col_u8)
|
||||
@ -151,3 +207,7 @@ IM2COL_OP(uint32_t, im2col_u32)
|
||||
IM2COL1D_OP(float, im2col1d_f32)
|
||||
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
||||
|
||||
UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
|
||||
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
|
||||
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
|
||||
|
@ -1518,6 +1518,50 @@ pub fn call_im2col_strided(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_upsample_nearest_2d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
out_w: usize,
|
||||
out_h: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let dst_el = out_w * out_h * shape[0] * shape[1];
|
||||
let scale_w = shape[2] as f32 / out_w as f32;
|
||||
let scale_h = shape[3] as f32 / out_h as f32;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
out_w,
|
||||
out_h,
|
||||
scale_w,
|
||||
scale_h,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
((m + b - 1) / b) as NSUInteger
|
||||
}
|
||||
|
Reference in New Issue
Block a user