mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Finished the unary kernels
- Added proper kernel type check (through modules + macro).
- Split contiguous and strided into 2 different kernels.
- Verified on long range + strided values.
This commit is contained in:
@ -1,12 +1,48 @@
|
|||||||
use metal::{Buffer, CompileOptions, Device, Function, Library};
|
use metal::{
|
||||||
|
Buffer, CommandBuffer, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
|
||||||
|
MTLSize,
|
||||||
|
};
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
pub const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &str = include_str!("affine.metal");
|
||||||
pub const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &str = include_str!("indexing.metal");
|
||||||
pub const UNARY: &str = include_str!("unary.metal");
|
const UNARY: &str = include_str!("unary.metal");
|
||||||
|
|
||||||
|
macro_rules! unary{
|
||||||
|
($($name:ident),+) => {
|
||||||
|
|
||||||
|
pub mod contiguous {
|
||||||
|
pub struct Kernel(pub &'static str);
|
||||||
|
$(
|
||||||
|
pub mod $name {
|
||||||
|
use super::Kernel;
|
||||||
|
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
|
||||||
|
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
|
||||||
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
||||||
|
}
|
||||||
|
)+
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod strided {
|
||||||
|
pub struct Kernel(pub &'static str);
|
||||||
|
$(
|
||||||
|
pub mod $name {
|
||||||
|
use super::Kernel;
|
||||||
|
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
|
||||||
|
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
|
||||||
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
||||||
|
}
|
||||||
|
)+
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod unary {
|
||||||
|
unary!(cos, sin, exp);
|
||||||
|
}
|
||||||
|
|
||||||
static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
||||||
let mut l = HashMap::new();
|
let mut l = HashMap::new();
|
||||||
@ -104,24 +140,112 @@ impl Kernels {
|
|||||||
Ok(func)
|
Ok(func)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn call_unary(
|
|
||||||
&self,
|
|
||||||
device: &Device,
|
|
||||||
library_name: &'static str,
|
|
||||||
name: &'static str,
|
|
||||||
input: &Buffer,
|
|
||||||
output: &mut Buffer,
|
|
||||||
length: usize,
|
|
||||||
) -> Result<(), MetalKernelError> {
|
|
||||||
let func = self.load_function(device, library_name, name)?;
|
|
||||||
call_unary(&func, input, output, length);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) {
|
pub fn call_unary_contiguous(
|
||||||
todo!("Call unary");
|
device: &Device,
|
||||||
|
command_buffer: &CommandBuffer,
|
||||||
|
kernels: &Kernels,
|
||||||
|
kernel_name: unary::contiguous::Kernel,
|
||||||
|
length: usize,
|
||||||
|
input: &Buffer,
|
||||||
|
output: &mut Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
assert_eq!(input.length(), output.length());
|
||||||
|
let func = kernels.load_function(device, "unary", kernel_name.0)?;
|
||||||
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
|
|
||||||
|
let pipeline = device
|
||||||
|
.new_compute_pipeline_state_with_function(
|
||||||
|
pipeline_state_descriptor.compute_function().unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
encoder.set_bytes(0, 4, void_ptr(&length));
|
||||||
|
encoder.set_buffer(1, Some(&input), 0);
|
||||||
|
encoder.set_buffer(2, Some(&output), 0);
|
||||||
|
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: 1,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.end_encoding();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub fn call_unary_strided(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBuffer,
|
||||||
|
kernels: &Kernels,
|
||||||
|
name: unary::strided::Kernel,
|
||||||
|
input: &Buffer,
|
||||||
|
shape: &[usize],
|
||||||
|
strides: &[usize],
|
||||||
|
offset: usize,
|
||||||
|
output: &mut Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let func = kernels.load_function(device, "unary", name.0)?;
|
||||||
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
|
|
||||||
|
let pipeline = device
|
||||||
|
.new_compute_pipeline_state_with_function(
|
||||||
|
pipeline_state_descriptor.compute_function().unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let num_dims: usize = shape.len() as usize;
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
let length: usize = shape.iter().product();
|
||||||
|
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
|
||||||
|
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
|
||||||
|
encoder.set_bytes(
|
||||||
|
2,
|
||||||
|
(shape.len() * std::mem::size_of::<usize>()) as u64,
|
||||||
|
shape.as_ptr() as *const c_void,
|
||||||
|
);
|
||||||
|
encoder.set_bytes(
|
||||||
|
3,
|
||||||
|
(strides.len() * std::mem::size_of::<usize>()) as u64,
|
||||||
|
strides.as_ptr() as *const c_void,
|
||||||
|
);
|
||||||
|
encoder.set_bytes(4, std::mem::size_of::<usize>() as u64, void_ptr(&offset));
|
||||||
|
|
||||||
|
encoder.set_buffer(5, Some(&input), 0);
|
||||||
|
encoder.set_buffer(6, Some(&output), 0);
|
||||||
|
|
||||||
|
let width = output.length();
|
||||||
|
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: 1,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.end_encoding();
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn void_ptr<T>(v: &T) -> *const c_void {
|
pub fn void_ptr<T>(v: &T) -> *const c_void {
|
||||||
@ -132,9 +256,7 @@ pub fn void_ptr<T>(v: &T) -> *const c_void {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
use metal::{
|
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
||||||
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLSize, NSUInteger,
|
|
||||||
};
|
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
|
||||||
fn device() -> Device {
|
fn device() -> Device {
|
||||||
@ -151,58 +273,63 @@ mod tests {
|
|||||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
|
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||||
let device = device();
|
let device = device();
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
let kernels = Kernels::new();
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_owned_command_buffer();
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let input = device.new_buffer_with_data(
|
let input = device.new_buffer_with_data(
|
||||||
v.as_ptr() as *const core::ffi::c_void,
|
v.as_ptr() as *const core::ffi::c_void,
|
||||||
(v.len() * core::mem::size_of::<T>()) as u64,
|
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||||
options,
|
options,
|
||||||
);
|
);
|
||||||
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
|
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
|
||||||
let library = device
|
call_unary_contiguous(
|
||||||
.new_library_with_source(UNARY, &CompileOptions::new())
|
&device,
|
||||||
.expect("Failed to load unary library");
|
&command_buffer,
|
||||||
let func = library.get_function(&format!("cos_{name}"), None).unwrap();
|
&kernels,
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
name,
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
v.len(),
|
||||||
|
&input,
|
||||||
|
&mut output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
let pipeline = device
|
fn run_strided<T: Clone>(
|
||||||
.new_compute_pipeline_state_with_function(
|
v: &[T],
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
kernel: unary::strided::Kernel,
|
||||||
)
|
shape: &[usize],
|
||||||
.unwrap();
|
strides: &[usize],
|
||||||
|
offset: usize,
|
||||||
let dim: u32 = v.len() as u32;
|
) -> Vec<T> {
|
||||||
// let num_dims: u32 = 1;
|
let device = device();
|
||||||
// let info = [v.len() as u32, 1];
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let command_buffer = command_queue.new_owned_command_buffer();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
let input = device.new_buffer_with_data(
|
||||||
|
v.as_ptr() as *const core::ffi::c_void,
|
||||||
encoder.set_bytes(0, 4, void_ptr(&dim));
|
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||||
|
options,
|
||||||
encoder.set_buffer(1, Some(&input), 0);
|
);
|
||||||
encoder.set_buffer(2, Some(&output), 0);
|
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
|
||||||
|
let kernels = Kernels::new();
|
||||||
let width = v.len() as NSUInteger;
|
call_unary_strided(
|
||||||
|
&device,
|
||||||
let thread_group_count = MTLSize {
|
&command_buffer,
|
||||||
width,
|
&kernels,
|
||||||
height: 1,
|
kernel,
|
||||||
depth: 1,
|
&input,
|
||||||
};
|
shape,
|
||||||
|
strides,
|
||||||
let thread_group_size = MTLSize {
|
offset,
|
||||||
width: pipeline.max_total_threads_per_threadgroup(),
|
&mut output,
|
||||||
height: 1,
|
)
|
||||||
depth: 1,
|
.unwrap();
|
||||||
};
|
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
||||||
encoder.end_encoding();
|
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
output.read_to_vec::<T>(v.len())
|
output.read_to_vec::<T>(v.len())
|
||||||
@ -211,10 +338,77 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn cos_f32() {
|
fn cos_f32() {
|
||||||
let v = vec![1.0f32, 2.0, 3.0];
|
let v = vec![1.0f32, 2.0, 3.0];
|
||||||
let results = run_cos(&v, "float");
|
let results = run(&v, unary::contiguous::cos::FLOAT);
|
||||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||||
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
|
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
|
||||||
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
||||||
|
|
||||||
|
let v = vec![1.0f32; 10_000];
|
||||||
|
let results = run(&v, unary::contiguous::cos::FLOAT);
|
||||||
|
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||||
|
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
|
||||||
|
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cos_f32_strided() {
|
||||||
|
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||||
|
// Shape = [6], strides = [1];
|
||||||
|
let shape = vec![6];
|
||||||
|
let strides = vec![1];
|
||||||
|
let offset = 0;
|
||||||
|
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||||
|
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||||
|
assert_eq!(
|
||||||
|
approx(results, 4),
|
||||||
|
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
approx(expected, 4),
|
||||||
|
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Contiguous
|
||||||
|
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||||
|
let shape = vec![3, 2];
|
||||||
|
let strides = vec![2, 1];
|
||||||
|
let offset = 0;
|
||||||
|
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||||
|
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||||
|
assert_eq!(
|
||||||
|
approx(results, 4),
|
||||||
|
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
approx(expected, 4),
|
||||||
|
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Transposed
|
||||||
|
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||||
|
let shape = vec![3, 2];
|
||||||
|
let strides = vec![1, 3];
|
||||||
|
let offset = 0;
|
||||||
|
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||||
|
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||||
|
assert_eq!(
|
||||||
|
approx(results, 4),
|
||||||
|
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
approx(expected, 4),
|
||||||
|
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Very large
|
||||||
|
let v = vec![1.0f32; 10_000];
|
||||||
|
let shape = vec![2, 5_000];
|
||||||
|
let strides = vec![2, 1];
|
||||||
|
let offset = 0;
|
||||||
|
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||||
|
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||||
|
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
|
||||||
|
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -360,7 +554,7 @@ mod tests {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|v| f16::from_f32(*v))
|
.map(|v| f16::from_f32(*v))
|
||||||
.collect();
|
.collect();
|
||||||
let results = run_cos(&v, "half");
|
let results = run(&v, unary::contiguous::cos::HALF);
|
||||||
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
|
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
|
||||||
assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
|
assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
|
||||||
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
|
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
|
||||||
|
@ -1,28 +1,19 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
#
|
|
||||||
METAL_FUNC bool is_contiguous(
|
struct Info{
|
||||||
constant size_t &num_dims,
|
device size_t &num_dims;
|
||||||
constant size_t *dims,
|
device size_t *dims;
|
||||||
constant size_t *strides
|
device size_t *strides;
|
||||||
) {
|
};
|
||||||
size_t acc = 1;
|
|
||||||
for (uint d = 0; d < num_dims; d++) {
|
|
||||||
uint dim_idx = num_dims - 1 - d;
|
|
||||||
if (acc != strides[dim_idx]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
acc *= dims[dim_idx];
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
constant size_t &num_dims,
|
constant size_t &num_dims,
|
||||||
constant size_t *dims,
|
constant size_t *dims,
|
||||||
constant size_t *strides
|
constant size_t *strides,
|
||||||
|
constant size_t &offset
|
||||||
) {
|
) {
|
||||||
uint strided_i = 0;
|
uint strided_i = offset;
|
||||||
for (uint d = 0; d < num_dims; d++) {
|
for (uint d = 0; d < num_dims; d++) {
|
||||||
uint dim_idx = num_dims - 1 - d;
|
uint dim_idx = num_dims - 1 - d;
|
||||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||||
@ -40,37 +31,40 @@ kernel void FN_NAME( \
|
|||||||
device const TYPENAME *input, \
|
device const TYPENAME *input, \
|
||||||
device TYPENAME *output, \
|
device TYPENAME *output, \
|
||||||
uint threadgroup_size [[threads_per_threadgroup]], \
|
uint threadgroup_size [[threads_per_threadgroup]], \
|
||||||
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
|
||||||
uint thread_index [[thread_index_in_threadgroup]] \
|
uint thread_index [[thread_index_in_threadgroup]] \
|
||||||
) { \
|
) { \
|
||||||
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
|
||||||
if (i > dim){ \
|
const size_t start = thread_index * length; \
|
||||||
return; \
|
const size_t stop = min(start + length, dim); \
|
||||||
|
for (size_t i = start; i < stop; i++){ \
|
||||||
|
output[i] = FN(input[i]); \
|
||||||
} \
|
} \
|
||||||
output[i] = FN(input[i]); \
|
|
||||||
}\
|
}\
|
||||||
kernel void FN_NAME_STRIDED( \
|
kernel void FN_NAME_STRIDED( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
constant size_t &num_dims, \
|
constant size_t &num_dims, \
|
||||||
constant size_t *info, \
|
constant size_t *dims, \
|
||||||
|
constant size_t *strides, \
|
||||||
|
constant size_t &offset, \
|
||||||
device const TYPENAME *input, \
|
device const TYPENAME *input, \
|
||||||
device TYPENAME *output, \
|
device TYPENAME *output, \
|
||||||
uint threadgroup_size [[threads_per_threadgroup]], \
|
uint threadgroup_size [[threads_per_threadgroup]], \
|
||||||
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
|
||||||
uint thread_index [[thread_index_in_threadgroup]] \
|
uint thread_index [[thread_index_in_threadgroup]] \
|
||||||
) { \
|
) { \
|
||||||
constant size_t *dims = info; \
|
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
|
||||||
constant size_t *strides = info + num_dims; \
|
const size_t start = thread_index * length; \
|
||||||
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
const size_t stop = min(start + length, dim); \
|
||||||
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
|
for (size_t i = start; i < stop; i++){ \
|
||||||
for (size_t i = start; i < stop; i++) { \
|
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides, offset)]); \
|
||||||
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
|
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
UNARY(cos, float, cos_float, cos_float_strided);
|
UNARY(cos, float, cos_float, cos_float_strided);
|
||||||
UNARY(cos, half, cos_half, cos_half_strided);
|
UNARY(cos, half, cos_half, cos_half_strided);
|
||||||
|
UNARY(sin, float, sin_float, sin_float_strided);
|
||||||
|
UNARY(sin, half, sin_half, sin_half_strided);
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided);
|
UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided);
|
||||||
|
UNARY(sin, bfloat, sin_bfloat, sin_bfloat_strided);
|
||||||
#endif
|
#endif
|
||||||
|
Reference in New Issue
Block a user