mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Adding the convolutions (1d + 2d) to candle on metal.
This commit is contained in:
@ -782,12 +782,71 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn conv1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConv1D,
|
||||
layout: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &ParamsConv1D,
|
||||
) -> Result<Self> {
|
||||
crate::bail!("conv1d metal")
|
||||
let device = self.device().clone();
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
|
||||
let stride = params.stride;
|
||||
let dilation = params.dilation;
|
||||
let padding = params.padding;
|
||||
let k_size = params.k_size;
|
||||
let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
|
||||
let dst_el = dims[0] * l_out * dims[1] * k_size;
|
||||
let dst = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "im2col1d_f32",
|
||||
dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_im2col1d_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
layout.shape().dims(),
|
||||
layout.stride(),
|
||||
(k_size, stride, padding, dilation),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let col = Self {
|
||||
buffer: dst,
|
||||
device,
|
||||
dtype: self.dtype,
|
||||
};
|
||||
let l_out = params.l_out();
|
||||
let b = params.b_size;
|
||||
let n = params.c_out;
|
||||
let k = params.k_size * params.c_in;
|
||||
let m = l_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
@ -802,12 +861,79 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConv2D,
|
||||
layout: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
crate::bail!("conv2d metal")
|
||||
let device = self.device().clone();
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
|
||||
let stride = params.stride;
|
||||
let dilation = params.dilation;
|
||||
let padding = params.padding;
|
||||
let h_k = params.k_h;
|
||||
let w_k = params.k_w;
|
||||
let h = dims[2];
|
||||
let w = dims[3];
|
||||
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
|
||||
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
|
||||
let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
|
||||
|
||||
let dst = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "im2col_f32",
|
||||
dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_im2col_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
layout.shape().dims(),
|
||||
layout.stride(),
|
||||
(h_k, w_k, stride, padding, dilation),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let col = Self {
|
||||
buffer: dst,
|
||||
device,
|
||||
dtype: self.dtype,
|
||||
};
|
||||
let h_out = params.out_h();
|
||||
let w_out = params.out_w();
|
||||
let b = params.b_size;
|
||||
let n = params.c_out;
|
||||
let k = params.k_h * params.k_w * params.c_in;
|
||||
let m = h_out * w_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
||||
.transpose(1, 2)?
|
||||
.transpose(1, 3)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
|
153
candle-metal-kernels/src/conv.metal
Normal file
153
candle-metal-kernels/src/conv.metal
Normal file
@ -0,0 +1,153 @@
|
||||
template <typename T>
|
||||
METAL_FUNC void im2col(
|
||||
constant size_t &dst_numel,
|
||||
constant size_t &h_out,
|
||||
constant size_t &w_out,
|
||||
constant size_t &h_k,
|
||||
constant size_t &w_k,
|
||||
constant size_t &stride,
|
||||
constant size_t &padding,
|
||||
constant size_t &dilation,
|
||||
constant size_t *src_dims,
|
||||
constant size_t *src_strides,
|
||||
device const T *src,
|
||||
device T *dst,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
// dst: (b_size, h_out, w_out, c_in, h_k, w_k)
|
||||
// src: (b_size, c_in, h_in, w_in)
|
||||
if (tid >= dst_numel) {
|
||||
return;
|
||||
}
|
||||
const size_t b_in = src_dims[0];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t h_in = src_dims[2];
|
||||
const size_t w_in = src_dims[3];
|
||||
|
||||
const size_t dst_s4 = w_k;
|
||||
const size_t dst_s3 = h_k * dst_s4;
|
||||
const size_t dst_s2 = c_in * dst_s3;
|
||||
const size_t dst_s1 = w_out * dst_s2;
|
||||
const size_t dst_s0 = h_out * dst_s1;
|
||||
|
||||
size_t tmp_tid = tid;
|
||||
const size_t b_idx = tmp_tid / dst_s0;
|
||||
tmp_tid -= b_idx * dst_s0;
|
||||
const size_t h_idx = tmp_tid / dst_s1;
|
||||
tmp_tid -= h_idx * dst_s1;
|
||||
const size_t w_idx = tmp_tid / dst_s2;
|
||||
tmp_tid -= w_idx * dst_s2;
|
||||
const size_t c_idx = tmp_tid / dst_s3;
|
||||
tmp_tid -= c_idx * dst_s3;
|
||||
const size_t h_k_idx = tmp_tid / dst_s4;
|
||||
tmp_tid -= h_k_idx * dst_s4;
|
||||
const size_t w_k_idx = tmp_tid;
|
||||
size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
|
||||
size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
|
||||
if (src_h_idx < padding || src_h_idx >= h_in + padding) {
|
||||
dst[tid] = static_cast<T>(0);
|
||||
}
|
||||
else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
|
||||
dst[tid] = static_cast<T>(0);
|
||||
}
|
||||
else {
|
||||
src_h_idx -= padding;
|
||||
src_w_idx -= padding;
|
||||
const size_t src_i =
|
||||
b_idx * src_strides[0]
|
||||
+ c_idx * src_strides[1]
|
||||
+ src_h_idx * src_strides[2]
|
||||
+ src_w_idx * src_strides[3];
|
||||
dst[tid] = src[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void im2col1d(
|
||||
constant size_t &dst_numel,
|
||||
constant size_t &l_out,
|
||||
constant size_t &l_k,
|
||||
constant size_t &stride,
|
||||
constant size_t &padding,
|
||||
constant size_t &dilation,
|
||||
constant size_t *src_dims,
|
||||
constant size_t *src_strides,
|
||||
device const T *src,
|
||||
device T *dst,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
// dst: (b_size, l_out, c_in, l_k)
|
||||
// src: (b_size, c_in, l_in)
|
||||
if (tid >= dst_numel) {
|
||||
return;
|
||||
}
|
||||
const size_t b_in = src_dims[0];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t l_in = src_dims[2];
|
||||
|
||||
const size_t dst_s2 = l_k;
|
||||
const size_t dst_s1 = c_in * dst_s2;
|
||||
const size_t dst_s0 = l_out * dst_s1;
|
||||
|
||||
size_t tmp_dst_i = tid;
|
||||
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||
tmp_dst_i -= b_idx * dst_s0;
|
||||
const size_t l_idx = tmp_dst_i / dst_s1;
|
||||
tmp_dst_i -= l_idx * dst_s1;
|
||||
const size_t c_idx = tmp_dst_i / dst_s2;
|
||||
tmp_dst_i -= c_idx * dst_s2;
|
||||
const size_t l_k_idx = tmp_dst_i;
|
||||
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
|
||||
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
|
||||
dst[tid] = static_cast<T>(0);
|
||||
}
|
||||
else {
|
||||
src_l_idx -= padding;
|
||||
const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2];
|
||||
dst[tid] = src[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
#define IM2COL_OP(T, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dst_numel, \
|
||||
constant size_t &h_out, \
|
||||
constant size_t &w_out, \
|
||||
constant size_t &h_k, \
|
||||
constant size_t &w_k, \
|
||||
constant size_t &stride, \
|
||||
constant size_t &padding, \
|
||||
constant size_t &dilation, \
|
||||
constant size_t *src_dims, \
|
||||
constant size_t *src_strides, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
#define IM2COL1D_OP(T, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dst_numel, \
|
||||
constant size_t &l_out, \
|
||||
constant size_t &l_k, \
|
||||
constant size_t &stride, \
|
||||
constant size_t &padding, \
|
||||
constant size_t &dilation, \
|
||||
constant size_t *src_dims, \
|
||||
constant size_t *src_strides, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
IM2COL_OP(float, im2col_f32)
|
||||
IM2COL_OP(uint8_t, im2col_u8)
|
||||
IM2COL_OP(uint32_t, im2col_u32)
|
||||
|
||||
IM2COL1D_OP(float, im2col1d_f32)
|
||||
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
@ -115,6 +116,7 @@ pub enum Source {
|
||||
Cast,
|
||||
Reduce,
|
||||
Mfa,
|
||||
Conv,
|
||||
}
|
||||
|
||||
macro_rules! ops{
|
||||
@ -225,6 +227,7 @@ impl Kernels {
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
@ -1418,6 +1421,103 @@ pub fn call_gemm(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_im2col1d_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
|
||||
let dst_el = shape[0] * l_out * shape[1] * k_size;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
l_out,
|
||||
k_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_im2col_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
|
||||
let h = shape[2];
|
||||
let w = shape[3];
|
||||
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
|
||||
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
|
||||
|
||||
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
h_out,
|
||||
w_out,
|
||||
h_k,
|
||||
w_k,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
shape,
|
||||
strides,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
((m + b - 1) / b) as NSUInteger
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::*;
|
||||
use half::{bf16, f16};
|
||||
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
@ -485,73 +485,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
read_to_vec(&dst_buffer, dst_el)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_add() {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
|
||||
let options = CompileOptions::new();
|
||||
let library = device.new_library_with_source(INDEXING, &options).unwrap();
|
||||
|
||||
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
||||
let right = [1.0f32; 15];
|
||||
let index = [0u32, 4, 2];
|
||||
let ids_dim_size = index.len() as u32;
|
||||
let dst_dim_size: u32 = 15;
|
||||
let left_size: u32 = 3;
|
||||
let right_size: u32 = 3;
|
||||
|
||||
let function = library.get_function("ia_u32_f32", None).unwrap();
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let index_buffer = new_buffer(&device, &index);
|
||||
let inputs_buffer = new_buffer(&device, &left);
|
||||
let outputs_buffer = new_buffer(&device, &right);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
&index_buffer,
|
||||
&inputs_buffer,
|
||||
&outputs_buffer,
|
||||
ids_dim_size,
|
||||
left_size,
|
||||
dst_dim_size,
|
||||
right_size
|
||||
)
|
||||
);
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: right.len() as NSUInteger,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: pipeline.max_total_threads_per_threadgroup(),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.dispatch_thread_groups(grid_size, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
let expected = vec![
|
||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
||||
];
|
||||
let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos_f16() {
|
||||
let v: Vec<f16> = [1.0f32, 2.0, 3.0]
|
||||
|
@ -64,12 +64,12 @@ kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
|
||||
output[tid] = TYPENAME(FN(float(input[tid]))); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -78,12 +78,12 @@ kernel void FN_NAME_STRIDED( \
|
||||
constant size_t *strides, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
|
||||
output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
|
||||
}
|
||||
|
||||
#define UNARY_OP(NAME) \
|
||||
|
Reference in New Issue
Block a user