mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d. * Enable the col2im variant. * Bugfix. * Revert the quantized tweak.
This commit is contained in:
@ -824,44 +824,102 @@ impl BackendStorage for MetalStorage {
|
||||
k_layout: &Layout,
|
||||
params: &ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||
|
||||
let can_use_col2im = k_layout.is_contiguous()
|
||||
&& params.dilation == 1
|
||||
&& params.padding == 0
|
||||
&& params.output_padding == 0;
|
||||
let l_out = params.l_out();
|
||||
let dst_el = params.c_out * l_out * params.b_size;
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "conv_transpose1d_f32",
|
||||
DType::F16 => "conv_transpose1d_f16",
|
||||
DType::BF16 => "conv_transpose1d_bf16",
|
||||
DType::U32 => "conv_transpose1d_u32",
|
||||
DType::U8 => "conv_transpose1d_u8",
|
||||
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
|
||||
let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im {
|
||||
let (b_size, c_in, l_in) = layout.shape().dims3()?;
|
||||
let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
|
||||
if c_in != c_in2 {
|
||||
crate::bail!(
|
||||
"convtr1d: shape mismatch on c_in {:?} {:?}",
|
||||
layout.shape(),
|
||||
k_layout.shape()
|
||||
)
|
||||
}
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "col2im1d_f32",
|
||||
DType::U32 => "col2im1d_u32",
|
||||
DType::U8 => "col2im1d_u8",
|
||||
dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"),
|
||||
};
|
||||
let col = {
|
||||
// This merges the last two dimensions of the kernel together.
|
||||
let kernel_l_mm = Layout::new(
|
||||
(b_size, c_in, k_size * c_out).into(),
|
||||
vec![0, k_size * c_out, 1],
|
||||
k_layout.start_offset(),
|
||||
);
|
||||
self.matmul(
|
||||
k,
|
||||
(b_size, l_in, c_out * k_size, c_in),
|
||||
&layout.transpose(1, 2)?,
|
||||
&kernel_l_mm,
|
||||
)?
|
||||
};
|
||||
candle_metal_kernels::call_col2im1d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
&[b_size, l_in, c_out, k_size],
|
||||
params.k_size,
|
||||
params.stride,
|
||||
BufferOffset::zero_offset(&col.buffer),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
buffer
|
||||
} else {
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "conv_transpose1d_f32",
|
||||
DType::F16 => "conv_transpose1d_f16",
|
||||
DType::BF16 => "conv_transpose1d_bf16",
|
||||
DType::U32 => "conv_transpose1d_u32",
|
||||
DType::U8 => "conv_transpose1d_u8",
|
||||
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_conv_transpose1d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
params.dilation,
|
||||
params.stride,
|
||||
params.padding,
|
||||
params.output_padding,
|
||||
params.c_out,
|
||||
l_out,
|
||||
params.b_size,
|
||||
layout.dims(),
|
||||
layout.stride(),
|
||||
k_layout.dims(),
|
||||
k_layout.stride(),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&k.buffer,
|
||||
k_layout.start_offset() * k.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
buffer
|
||||
};
|
||||
candle_metal_kernels::call_conv_transpose1d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
params.dilation,
|
||||
params.stride,
|
||||
params.padding,
|
||||
params.output_padding,
|
||||
params.c_out,
|
||||
l_out,
|
||||
params.b_size,
|
||||
layout.dims(),
|
||||
layout.stride(),
|
||||
k_layout.dims(),
|
||||
k_layout.stride(),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&k.buffer,
|
||||
k_layout.start_offset() * k.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
|
||||
}
|
||||
|
||||
|
@ -68,6 +68,50 @@ METAL_FUNC void im2col(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void col2im1d(
|
||||
constant size_t &dst_el,
|
||||
constant size_t &l_out,
|
||||
constant size_t &l_in,
|
||||
constant size_t &c_out,
|
||||
constant size_t &k_size,
|
||||
constant size_t &stride,
|
||||
device const T *src,
|
||||
device T *dst,
|
||||
uint dst_i [[ thread_position_in_grid ]]
|
||||
) {
|
||||
// src: (b_size, l_in, c_out, l_k)
|
||||
// dst: (b_size, c_out, l_out)
|
||||
if (dst_i >= dst_el) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t dst_s0 = c_out * l_out;
|
||||
const size_t dst_s1 = l_out;
|
||||
const size_t src_s0 = c_out * k_size * l_in;
|
||||
const size_t src_s1 = c_out * k_size;
|
||||
const size_t src_s2 = k_size;
|
||||
|
||||
size_t tmp_dst_i = dst_i;
|
||||
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||
tmp_dst_i -= b_idx * dst_s0;
|
||||
const size_t c_idx = tmp_dst_i / dst_s1;
|
||||
tmp_dst_i -= c_idx * dst_s1;
|
||||
const int l_out_idx = tmp_dst_i;
|
||||
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
|
||||
int l_in_idx = l_out_idx / stride;
|
||||
int k0 = l_out_idx - l_in_idx * stride;
|
||||
// l_out_idx = l_in_idx * stride + k0
|
||||
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
|
||||
if (l_in_idx < l_in) {
|
||||
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
|
||||
dst[dst_i] += src[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void im2col1d(
|
||||
constant size_t &dst_numel,
|
||||
@ -190,6 +234,21 @@ kernel void FN_NAME( \
|
||||
) { \
|
||||
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
#define COL2IM1D_OP(T, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dst_el, \
|
||||
constant size_t &l_out, \
|
||||
constant size_t &l_in, \
|
||||
constant size_t &c_out, \
|
||||
constant size_t &k_size, \
|
||||
constant size_t &stride, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \
|
||||
} \
|
||||
|
||||
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
@ -493,6 +552,10 @@ IM2COL_OP(uint32_t, im2col_u32)
|
||||
IM2COL_OP(bfloat, im2col_bf16)
|
||||
#endif
|
||||
|
||||
COL2IM1D_OP(float, col2im1d_f32)
|
||||
COL2IM1D_OP(uint8_t, col2im1d_u8)
|
||||
COL2IM1D_OP(uint32_t, col2im1d_u32)
|
||||
|
||||
IM2COL1D_OP(float, im2col1d_f32)
|
||||
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
||||
@ -533,4 +596,4 @@ CONVT2D_OP(float, float, conv_transpose2d_f32)
|
||||
CONVT2D_OP(half, float, conv_transpose2d_f16)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CONVT1D_OP(bfloat, float, conv_transpose2d_bf16)
|
||||
#endif
|
||||
#endif
|
||||
|
@ -1651,6 +1651,39 @@ pub fn call_im2col1d_strided(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_col2im1d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let l_in = shape[1];
|
||||
let c_out = shape[2];
|
||||
let l_out = (l_in - 1) * stride + k_size;
|
||||
let dst_el = shape[0] * c_out * l_out;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
|
||||
);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_im2col_strided(
|
||||
device: &Device,
|
||||
|
Reference in New Issue
Block a user