mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00

Merge pull request #1461 from huggingface/metal-conv
Adding the convolutions (1d + 2d) to candle on metal.
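With this PR the conv1d and conv2d ops no longer bail on the Metal backend but go through an im2col + matmul lowering. A minimal usage sketch, assuming a Metal-capable machine and the current candle_core API (Device::new_metal, Tensor::conv1d(kernel, padding, stride, dilation, groups)); this snippet is illustrative and not part of the commit:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::new_metal(0)?;                       // assumes Metal is available
    let x = Tensor::randn(0f32, 1f32, (1, 4, 16), &device)?;  // (batch, c_in, l_in)
    let w = Tensor::randn(0f32, 1f32, (8, 4, 3), &device)?;   // (c_out, c_in, k_size)
    let y = x.conv1d(&w, 1, 1, 1, 1)?;                        // padding, stride, dilation, groups
    println!("{:?}", y.shape());                              // (1, 8, 16) with these settings
    Ok(())
}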
@@ -782,12 +782,72 @@ impl BackendStorage for MetalStorage {

     fn conv1d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &ParamsConv1D,
+        layout: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &ParamsConv1D,
     ) -> Result<Self> {
-        crate::bail!("conv1d metal")
+        let device = self.device().clone();
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let strides = layout.stride();
+
+        let stride = params.stride;
+        let dilation = params.dilation;
+        let padding = params.padding;
+        let k_size = params.k_size;
+        let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
+        let dst_el = dims[0] * l_out * dims[1] * k_size;
+        let dst = self
+            .device
+            .new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
+        let command_buffer = self.device.command_buffer()?;
+        let name = match self.dtype {
+            DType::F32 => "im2col1d_f32",
+            dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
+        };
+        candle_metal_kernels::call_im2col1d_strided(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            layout.shape().dims(),
+            strides,
+            (k_size, stride, padding, dilation),
+            &self.buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
+            &dst,
+        )
+        .map_err(MetalError::from)?;
+        let col = Self {
+            buffer: dst,
+            device,
+            dtype: self.dtype,
+        };
+        let l_out = params.l_out();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_size * params.c_in;
+        let m = l_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

     fn conv_transpose1d(
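The conv1d path above first expands the input into an im2col buffer and then reduces the problem to a batched matmul. A standalone sketch of the shape bookkeeping, mirroring the formulas in the diff (illustrative only, not part of the commit):

// Shapes used by the im2col + matmul lowering; names follow the variables above.
fn conv1d_im2col_shapes(
    b: usize, c_in: usize, l_in: usize, c_out: usize,
    k_size: usize, stride: usize, padding: usize, dilation: usize,
) -> ((usize, usize, usize), (usize, usize, usize)) {
    // Same output-length formula as in the kernel launch above.
    let l_out = (l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
    // im2col buffer viewed as (b, m, k) with m = l_out and k = c_in * k_size;
    // multiplying by the kernel viewed as (k, c_out) gives (b, l_out, c_out),
    // which is finally transposed to the usual (b, c_out, l_out).
    let col = (b, l_out, c_in * k_size);
    let out = (b, c_out, l_out);
    (col, out)
}

// Example: conv1d_im2col_shapes(1, 4, 10, 8, 3, 2, 1, 1) == ((1, 5, 12), (1, 8, 5)).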
@@ -802,12 +862,79 @@ impl BackendStorage for MetalStorage {

     fn conv2d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &ParamsConv2D,
+        layout: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &ParamsConv2D,
     ) -> Result<Self> {
-        crate::bail!("conv2d metal")
+        let device = self.device().clone();
+        let shape = layout.shape();
+        let dims = shape.dims();
+
+        let stride = params.stride;
+        let dilation = params.dilation;
+        let padding = params.padding;
+        let h_k = params.k_h;
+        let w_k = params.k_w;
+        let h = dims[2];
+        let w = dims[3];
+        let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
+        let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
+        let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
+
+        let dst = self
+            .device
+            .new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
+        let command_buffer = self.device.command_buffer()?;
+        let name = match self.dtype {
+            DType::F32 => "im2col_f32",
+            dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
+        };
+        candle_metal_kernels::call_im2col_strided(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            layout.shape().dims(),
+            layout.stride(),
+            (h_k, w_k, stride, padding, dilation),
+            &self.buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
+            &dst,
+        )
+        .map_err(MetalError::from)?;
+        let col = Self {
+            buffer: dst,
+            device,
+            dtype: self.dtype,
+        };
+        let h_out = params.out_h();
+        let w_out = params.out_w();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_h * params.k_w * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, h_out, w_out, n))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

     fn conv_transpose2d(
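Conv2d follows the same scheme; the matmul output is produced in (b, h_out, w_out, c_out) order and the two trailing transposes permute it back to NCHW. A hedged sketch of the dimension bookkeeping (plain Rust, illustrative only):

// Shapes involved in the conv2d im2col lowering above.
fn conv2d_im2col_shapes(
    b: usize, c_in: usize, h: usize, w: usize, c_out: usize,
    h_k: usize, w_k: usize, stride: usize, padding: usize, dilation: usize,
) -> ((usize, usize, usize), (usize, usize, usize, usize)) {
    let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
    let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
    // im2col buffer as (b, m, k): m = h_out * w_out, k = c_in * h_k * w_k.
    let col = (b, h_out * w_out, c_in * h_k * w_k);
    // The matmul gives (b, h_out * w_out, c_out); read as (b, h_out, w_out, c_out),
    // transpose(1, 2) then transpose(1, 3) turn it into (b, c_out, h_out, w_out).
    let out = (b, c_out, h_out, w_out);
    (col, out)
}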
candle-metal-kernels/src/conv.metal (new file, 153 lines)
@@ -0,0 +1,153 @@
#include <metal_stdlib>

using namespace metal;

template <typename T>
METAL_FUNC void im2col(
    constant size_t &dst_numel,
    constant size_t &h_out,
    constant size_t &w_out,
    constant size_t &h_k,
    constant size_t &w_k,
    constant size_t &stride,
    constant size_t &padding,
    constant size_t &dilation,
    constant size_t *src_dims,
    constant size_t *src_strides,
    device const T *src,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    // dst: (b_size, h_out, w_out, c_in, h_k, w_k)
    // src: (b_size, c_in, h_in, w_in)
    if (tid >= dst_numel) {
        return;
    }
    const size_t b_in = src_dims[0];
    const size_t c_in = src_dims[1];
    const size_t h_in = src_dims[2];
    const size_t w_in = src_dims[3];

    const size_t dst_s4 = w_k;
    const size_t dst_s3 = h_k * dst_s4;
    const size_t dst_s2 = c_in * dst_s3;
    const size_t dst_s1 = w_out * dst_s2;
    const size_t dst_s0 = h_out * dst_s1;

    size_t tmp_tid = tid;
    const size_t b_idx = tmp_tid / dst_s0;
    tmp_tid -= b_idx * dst_s0;
    const size_t h_idx = tmp_tid / dst_s1;
    tmp_tid -= h_idx * dst_s1;
    const size_t w_idx = tmp_tid / dst_s2;
    tmp_tid -= w_idx * dst_s2;
    const size_t c_idx = tmp_tid / dst_s3;
    tmp_tid -= c_idx * dst_s3;
    const size_t h_k_idx = tmp_tid / dst_s4;
    tmp_tid -= h_k_idx * dst_s4;
    const size_t w_k_idx = tmp_tid;
    size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
    size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
    if (src_h_idx < padding || src_h_idx >= h_in + padding) {
        dst[tid] = static_cast<T>(0);
    }
    else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
        dst[tid] = static_cast<T>(0);
    }
    else {
        src_h_idx -= padding;
        src_w_idx -= padding;
        const size_t src_i =
            b_idx * src_strides[0]
            + c_idx * src_strides[1]
            + src_h_idx * src_strides[2]
            + src_w_idx * src_strides[3];
        dst[tid] = src[src_i];
    }
}

template <typename T>
METAL_FUNC void im2col1d(
    constant size_t &dst_numel,
    constant size_t &l_out,
    constant size_t &l_k,
    constant size_t &stride,
    constant size_t &padding,
    constant size_t &dilation,
    constant size_t *src_dims,
    constant size_t *src_strides,
    device const T *src,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    // dst: (b_size, l_out, c_in, l_k)
    // src: (b_size, c_in, l_in)
    if (tid >= dst_numel) {
        return;
    }
    const size_t b_in = src_dims[0];
    const size_t c_in = src_dims[1];
    const size_t l_in = src_dims[2];

    const size_t dst_s2 = l_k;
    const size_t dst_s1 = c_in * dst_s2;
    const size_t dst_s0 = l_out * dst_s1;

    size_t tmp_dst_i = tid;
    const size_t b_idx = tmp_dst_i / dst_s0;
    tmp_dst_i -= b_idx * dst_s0;
    const size_t l_idx = tmp_dst_i / dst_s1;
    tmp_dst_i -= l_idx * dst_s1;
    const size_t c_idx = tmp_dst_i / dst_s2;
    tmp_dst_i -= c_idx * dst_s2;
    const size_t l_k_idx = tmp_dst_i;
    size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
    if (src_l_idx < padding || src_l_idx >= l_in + padding) {
        dst[tid] = static_cast<T>(0);
    }
    else {
        src_l_idx -= padding;
        const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2];
        dst[tid] = src[src_i];
    }
}

#define IM2COL_OP(T, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &dst_numel, \
    constant size_t &h_out, \
    constant size_t &w_out, \
    constant size_t &h_k, \
    constant size_t &w_k, \
    constant size_t &stride, \
    constant size_t &padding, \
    constant size_t &dilation, \
    constant size_t *src_dims, \
    constant size_t *src_strides, \
    device const T *src, \
    device T *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \

#define IM2COL1D_OP(T, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &dst_numel, \
    constant size_t &l_out, \
    constant size_t &l_k, \
    constant size_t &stride, \
    constant size_t &padding, \
    constant size_t &dilation, \
    constant size_t *src_dims, \
    constant size_t *src_strides, \
    device const T *src, \
    device T *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \

IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)

IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
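Each thread of the im2col kernels handles a single destination element and recovers its coordinates by repeated division by the destination strides. The same decomposition written out in plain Rust for reference (an illustrative sketch, not part of the commit):

// Decompose a flat destination index into (b, h, w, c, h_k, w_k), mirroring the
// stride arithmetic of the 2-D im2col kernel above.
fn decompose_dst_index(
    tid: usize, h_out: usize, w_out: usize, c_in: usize, h_k: usize, w_k: usize,
) -> (usize, usize, usize, usize, usize, usize) {
    let dst_s4 = w_k;
    let dst_s3 = h_k * dst_s4;
    let dst_s2 = c_in * dst_s3;
    let dst_s1 = w_out * dst_s2;
    let dst_s0 = h_out * dst_s1;
    let mut t = tid;
    let b_idx = t / dst_s0;
    t -= b_idx * dst_s0;
    let h_idx = t / dst_s1;
    t -= h_idx * dst_s1;
    let w_idx = t / dst_s2;
    t -= w_idx * dst_s2;
    let c_idx = t / dst_s3;
    t -= c_idx * dst_s3;
    let h_k_idx = t / dst_s4;
    t -= h_k_idx * dst_s4;
    (b_idx, h_idx, w_idx, c_idx, h_k_idx, t) // last component is w_k_idx
}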
@@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
+const CONV: &str = include_str!("conv.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");

 /// Most kernels apply similarly across the tensors
@@ -115,6 +116,7 @@ pub enum Source {
     Cast,
     Reduce,
     Mfa,
+    Conv,
 }

 macro_rules! ops{
@@ -225,6 +227,7 @@ impl Kernels {
             Source::Indexing => INDEXING,
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
+            Source::Conv => CONV,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -1298,7 +1301,7 @@ pub fn call_gemm(
     let fused_activation = false;
     let fused_bias = false;
     let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
-        let m_simd = 16;
+        let m_simd = 8;
         let n_simd = 8;
         let k_simd = 64;
         let m_splits = 1;
@@ -1307,7 +1310,7 @@ pub fn call_gemm(
     } else {
         let m_simd = 40;
         let n_simd = 40;
-        let k_simd = 8;
+        let k_simd = 32;
         let m_splits = 1;
         let n_splits = 1;
         (m_simd, n_simd, k_simd, m_splits, n_splits)
@@ -1418,6 +1421,103 @@ pub fn call_gemm(
     Ok(())
 }

+#[allow(clippy::too_many_arguments)]
+pub fn call_im2col1d_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    (k_size, stride, padding, dilation): (usize, usize, usize, usize),
+    input: &Buffer,
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+    let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
+    let dst_el = shape[0] * l_out * shape[1] * k_size;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            dst_el,
+            l_out,
+            k_size,
+            stride,
+            padding,
+            dilation,
+            shape,
+            strides,
+            (input, input_offset),
+            output
+        )
+    );
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_im2col_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
+    input: &Buffer,
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+
+    let h = shape[2];
+    let w = shape[3];
+    let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
+    let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
+
+    let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            dst_el,
+            h_out,
+            w_out,
+            h_k,
+            w_k,
+            stride,
+            padding,
+            dilation,
+            shape,
+            strides,
+            (input, input_offset),
+            output
+        )
+    );
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
 fn divide(m: usize, b: usize) -> NSUInteger {
     ((m + b - 1) / b) as NSUInteger
 }
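Both launchers dispatch one thread per destination element (dst_el in total); the threadgroup sizing comes from linear_split, which is not shown in this diff. A plausible sketch of such a 1-D split, using the same ceil division as the divide helper above (max_threads is a stand-in for the pipeline's max_total_threads_per_threadgroup; the real linear_split may differ):

// Hypothetical 1-D dispatch split, illustrative only.
fn linear_split_sketch(length: usize, max_threads: usize) -> (usize, usize) {
    let group_size = length.min(max_threads);
    let group_count = (length + group_size - 1) / group_size; // ceil division, as in `divide`
    (group_count, group_size)
}

// Example: linear_split_sketch(10_000, 1024) == (10, 1024), ten groups of 1024 threads covering dst_el = 10_000.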
@@ -1,6 +1,6 @@
 use super::*;
 use half::{bf16, f16};
-use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
+use metal::{Device, MTLResourceOptions};

 fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
     let ptr = buffer.contents() as *const T;
@@ -485,73 +485,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
     read_to_vec(&dst_buffer, dst_el)
 }

-#[test]
-fn index_add() {
-    let device = Device::system_default().expect("no device found");
-
-    let options = CompileOptions::new();
-    let library = device.new_library_with_source(INDEXING, &options).unwrap();
-
-    let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
-    let right = [1.0f32; 15];
-    let index = [0u32, 4, 2];
-    let ids_dim_size = index.len() as u32;
-    let dst_dim_size: u32 = 15;
-    let left_size: u32 = 3;
-    let right_size: u32 = 3;
-
-    let function = library.get_function("ia_u32_f32", None).unwrap();
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(&function)
-        .unwrap();
-
-    let command_queue = device.new_command_queue();
-    let command_buffer = command_queue.new_command_buffer();
-    let encoder = command_buffer.new_compute_command_encoder();
-
-    encoder.set_compute_pipeline_state(&pipeline);
-
-    let index_buffer = new_buffer(&device, &index);
-    let inputs_buffer = new_buffer(&device, &left);
-    let outputs_buffer = new_buffer(&device, &right);
-
-    set_params!(
-        encoder,
-        (
-            &index_buffer,
-            &inputs_buffer,
-            &outputs_buffer,
-            ids_dim_size,
-            left_size,
-            dst_dim_size,
-            right_size
-        )
-    );
-
-    let grid_size = MTLSize {
-        width: right.len() as NSUInteger,
-        height: 1,
-        depth: 1,
-    };
-
-    let thread_group_size = MTLSize {
-        width: pipeline.max_total_threads_per_threadgroup(),
-        height: 1,
-        depth: 1,
-    };
-
-    encoder.dispatch_thread_groups(grid_size, thread_group_size);
-    encoder.end_encoding();
-    command_buffer.commit();
-    command_buffer.wait_until_completed();
-
-    let expected = vec![
-        2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
-    ];
-    let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
-    assert_eq!(result, expected);
-}
-
 #[test]
 fn cos_f16() {
     let v: Vec<f16> = [1.0f32, 2.0, 3.0]
@@ -64,12 +64,12 @@ kernel void FN_NAME( \
     constant size_t &dim, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
+    output[tid] = TYPENAME(FN(float(input[tid]))); \
 }\
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -78,12 +78,12 @@ kernel void FN_NAME_STRIDED( \
     constant size_t *strides, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
+    output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
 }

 #define UNARY_OP(NAME) \