mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Finished the unary
- Added proper kernel type check (through modules + macro) - split contiguous and strided into 2 different kernels - Verified on long range + strided values.
This commit is contained in:
@ -1,12 +1,48 @@
|
||||
use metal::{Buffer, CompileOptions, Device, Function, Library};
|
||||
use metal::{
|
||||
Buffer, CommandBuffer, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
|
||||
MTLSize,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
|
||||
pub const AFFINE: &str = include_str!("affine.metal");
|
||||
pub const INDEXING: &str = include_str!("indexing.metal");
|
||||
pub const UNARY: &str = include_str!("unary.metal");
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
|
||||
macro_rules! unary{
|
||||
($($name:ident),+) => {
|
||||
|
||||
pub mod contiguous {
|
||||
pub struct Kernel(pub &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
||||
}
|
||||
)+
|
||||
}
|
||||
|
||||
pub mod strided {
|
||||
pub struct Kernel(pub &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
||||
}
|
||||
)+
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub mod unary {
|
||||
unary!(cos, sin, exp);
|
||||
}
|
||||
|
||||
static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
||||
let mut l = HashMap::new();
|
||||
@ -104,24 +140,112 @@ impl Kernels {
|
||||
Ok(func)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_unary(
|
||||
&self,
|
||||
device: &Device,
|
||||
library_name: &'static str,
|
||||
name: &'static str,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
length: usize,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = self.load_function(device, library_name, name)?;
|
||||
call_unary(&func, input, output, length);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) {
|
||||
todo!("Call unary");
|
||||
pub fn call_unary_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBuffer,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, "unary", kernel_name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
encoder.set_bytes(0, 4, void_ptr(&length));
|
||||
encoder.set_buffer(1, Some(&input), 0);
|
||||
encoder.set_buffer(2, Some(&output), 0);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
pub fn call_unary_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBuffer,
|
||||
kernels: &Kernels,
|
||||
name: unary::strided::Kernel,
|
||||
input: &Buffer,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
offset: usize,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, "unary", name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let num_dims: usize = shape.len() as usize;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
|
||||
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
|
||||
encoder.set_bytes(
|
||||
2,
|
||||
(shape.len() * std::mem::size_of::<usize>()) as u64,
|
||||
shape.as_ptr() as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
3,
|
||||
(strides.len() * std::mem::size_of::<usize>()) as u64,
|
||||
strides.as_ptr() as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(4, std::mem::size_of::<usize>() as u64, void_ptr(&offset));
|
||||
|
||||
encoder.set_buffer(5, Some(&input), 0);
|
||||
encoder.set_buffer(6, Some(&output), 0);
|
||||
|
||||
let width = output.length();
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn void_ptr<T>(v: &T) -> *const c_void {
|
||||
@ -132,9 +256,7 @@ pub fn void_ptr<T>(v: &T) -> *const c_void {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use half::f16;
|
||||
use metal::{
|
||||
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLSize, NSUInteger,
|
||||
};
|
||||
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
||||
use std::mem;
|
||||
|
||||
fn device() -> Device {
|
||||
@ -151,58 +273,63 @@ mod tests {
|
||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||
}
|
||||
|
||||
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
|
||||
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let command_buffer = command_queue.new_owned_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||
options,
|
||||
);
|
||||
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
|
||||
let library = device
|
||||
.new_library_with_source(UNARY, &CompileOptions::new())
|
||||
.expect("Failed to load unary library");
|
||||
let func = library.get_function(&format!("cos_{name}"), None).unwrap();
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
|
||||
call_unary_contiguous(
|
||||
&device,
|
||||
&command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(v.len())
|
||||
}
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let dim: u32 = v.len() as u32;
|
||||
// let num_dims: u32 = 1;
|
||||
// let info = [v.len() as u32, 1];
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
encoder.set_bytes(0, 4, void_ptr(&dim));
|
||||
|
||||
encoder.set_buffer(1, Some(&input), 0);
|
||||
encoder.set_buffer(2, Some(&output), 0);
|
||||
|
||||
let width = v.len() as NSUInteger;
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: pipeline.max_total_threads_per_threadgroup(),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
fn run_strided<T: Clone>(
|
||||
v: &[T],
|
||||
kernel: unary::strided::Kernel,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_owned_command_buffer();
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
|
||||
let kernels = Kernels::new();
|
||||
call_unary_strided(
|
||||
&device,
|
||||
&command_buffer,
|
||||
&kernels,
|
||||
kernel,
|
||||
&input,
|
||||
shape,
|
||||
strides,
|
||||
offset,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(v.len())
|
||||
@ -211,10 +338,77 @@ mod tests {
|
||||
#[test]
|
||||
fn cos_f32() {
|
||||
let v = vec![1.0f32, 2.0, 3.0];
|
||||
let results = run_cos(&v, "float");
|
||||
let results = run(&v, unary::contiguous::cos::FLOAT);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
|
||||
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
||||
|
||||
let v = vec![1.0f32; 10_000];
|
||||
let results = run(&v, unary::contiguous::cos::FLOAT);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
|
||||
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos_f32_strided() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
// Shape = [6], strides = [1];
|
||||
let shape = vec![6];
|
||||
let strides = vec![1];
|
||||
let offset = 0;
|
||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||
);
|
||||
assert_eq!(
|
||||
approx(expected, 4),
|
||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||
);
|
||||
|
||||
// Contiguous
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let shape = vec![3, 2];
|
||||
let strides = vec![2, 1];
|
||||
let offset = 0;
|
||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||
);
|
||||
assert_eq!(
|
||||
approx(expected, 4),
|
||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||
);
|
||||
|
||||
// Transposed
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let shape = vec![3, 2];
|
||||
let strides = vec![1, 3];
|
||||
let offset = 0;
|
||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
|
||||
);
|
||||
assert_eq!(
|
||||
approx(expected, 4),
|
||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
||||
);
|
||||
|
||||
// Very large
|
||||
let v = vec![1.0f32; 10_000];
|
||||
let shape = vec![2, 5_000];
|
||||
let strides = vec![2, 1];
|
||||
let offset = 0;
|
||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
|
||||
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -360,7 +554,7 @@ mod tests {
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let results = run_cos(&v, "half");
|
||||
let results = run(&v, unary::contiguous::cos::HALF);
|
||||
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
|
||||
assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
|
||||
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
|
||||
|
@ -1,28 +1,19 @@
|
||||
#include <metal_stdlib>
|
||||
#
|
||||
METAL_FUNC bool is_contiguous(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
size_t acc = 1;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
if (acc != strides[dim_idx]) {
|
||||
return false;
|
||||
}
|
||||
acc *= dims[dim_idx];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
struct Info{
|
||||
device size_t &num_dims;
|
||||
device size_t *dims;
|
||||
device size_t *strides;
|
||||
};
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
constant size_t *strides,
|
||||
constant size_t &offset
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
uint strided_i = offset;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
@ -40,37 +31,40 @@ kernel void FN_NAME( \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint threadgroup_size [[threads_per_threadgroup]], \
|
||||
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
||||
uint thread_index [[thread_index_in_threadgroup]] \
|
||||
) { \
|
||||
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
||||
if (i > dim){ \
|
||||
return; \
|
||||
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
|
||||
const size_t start = thread_index * length; \
|
||||
const size_t stop = min(start + length, dim); \
|
||||
for (size_t i = start; i < stop; i++){ \
|
||||
output[i] = FN(input[i]); \
|
||||
} \
|
||||
output[i] = FN(input[i]); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *info, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &offset, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint threadgroup_size [[threads_per_threadgroup]], \
|
||||
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
||||
uint thread_index [[thread_index_in_threadgroup]] \
|
||||
) { \
|
||||
constant size_t *dims = info; \
|
||||
constant size_t *strides = info + num_dims; \
|
||||
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
||||
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
|
||||
for (size_t i = start; i < stop; i++) { \
|
||||
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
|
||||
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
|
||||
const size_t start = thread_index * length; \
|
||||
const size_t stop = min(start + length, dim); \
|
||||
for (size_t i = start; i < stop; i++){ \
|
||||
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides, offset)]); \
|
||||
} \
|
||||
}
|
||||
|
||||
UNARY(cos, float, cos_float, cos_float_strided);
|
||||
UNARY(cos, half, cos_half, cos_half_strided);
|
||||
UNARY(sin, float, sin_float, sin_float_strided);
|
||||
UNARY(sin, half, sin_half, sin_half_strided);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided);
|
||||
UNARY(sin, bfloat, sin_bfloat, sin_bfloat_strided);
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user