Finished the unary kernels

- Added a proper kernel type check (through modules + a macro)
- Split contiguous and strided into two different kernels (see the usage sketch below)
- Verified on long-range and strided values.
Nicolas Patry
2023-11-06 23:12:12 +01:00
parent cd68c96803
commit 7ff17d92b3
2 changed files with 288 additions and 100 deletions
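
For context, a minimal usage sketch of the two new entry points, mirroring the test code in the diff below; `device`, `command_buffer`, `kernels`, `length`, `shape`, `strides`, `offset`, and the buffers are assumed to be set up as in the tests:

// Contiguous: flat length, macro-generated kernel name "cos_float".
call_unary_contiguous(
    &device,
    &command_buffer,
    &kernels,
    unary::contiguous::cos::FLOAT,
    length,
    &input,
    &mut output,
)?;

// Strided: shape/strides/offset replace the flat length ("cos_float_strided").
call_unary_strided(
    &device,
    &command_buffer,
    &kernels,
    unary::strided::cos::FLOAT,
    &input,
    &shape,
    &strides,
    offset,
    &mut output,
)?;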


@@ -1,12 +1,48 @@
use metal::{Buffer, CompileOptions, Device, Function, Library};
use metal::{
Buffer, CommandBuffer, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
MTLSize,
};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
pub const AFFINE: &str = include_str!("affine.metal");
pub const INDEXING: &str = include_str!("indexing.metal");
pub const UNARY: &str = include_str!("unary.metal");
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
macro_rules! unary {
($($name:ident),+) => {
pub mod contiguous {
pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
}
)+
}
pub mod strided {
pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
}
)+
}
};
}
pub mod unary {
unary!(cos, sin, exp);
}
static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
let mut l = HashMap::new();
@@ -104,24 +140,112 @@ impl Kernels {
Ok(func)
}
}
pub fn call_unary(
&self,
device: &Device,
library_name: &'static str,
name: &'static str,
input: &Buffer,
output: &mut Buffer,
length: usize,
) -> Result<(), MetalKernelError> {
let func = self.load_function(device, library_name, name)?;
call_unary(&func, input, output, length);
Ok(())
}
}
fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) {
todo!("Call unary");
pub fn call_unary_contiguous(
device: &Device,
command_buffer: &CommandBuffer,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, "unary", kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBuffer,
kernels: &Kernels,
name: unary::strided::Kernel,
input: &Buffer,
shape: &[usize],
strides: &[usize],
offset: usize,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, "unary", name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
encoder.set_bytes(
2,
(shape.len() * std::mem::size_of::<usize>()) as u64,
shape.as_ptr() as *const c_void,
);
encoder.set_bytes(
3,
(strides.len() * std::mem::size_of::<usize>()) as u64,
strides.as_ptr() as *const c_void,
);
encoder.set_bytes(4, std::mem::size_of::<usize>() as u64, void_ptr(&offset));
encoder.set_buffer(5, Some(&input), 0);
encoder.set_buffer(6, Some(&output), 0);
let width = length as u64;
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn void_ptr<T>(v: &T) -> *const c_void {
@@ -132,9 +256,7 @@ pub fn void_ptr<T>(v: &T) -> *const c_void {
mod tests {
use super::*;
use half::f16;
use metal::{
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLSize, NSUInteger,
};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
use std::mem;
fn device() -> Device {
@@ -151,58 +273,63 @@ mod tests {
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let options = MTLResourceOptions::StorageModeManaged;
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let command_buffer = command_queue.new_owned_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
let library = device
.new_library_with_source(UNARY, &CompileOptions::new())
.expect("Failed to load unary library");
let func = library.get_function(&format!("cos_{name}"), None).unwrap();
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
call_unary_contiguous(
&device,
&command_buffer,
&kernels,
name,
v.len(),
&input,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let dim: u32 = v.len() as u32;
// let num_dims: u32 = 1;
// let info = [v.len() as u32, 1];
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&dim));
encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0);
let width = v.len() as NSUInteger;
let thread_group_count = MTLSize {
width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
fn run_strided<T: Clone>(
v: &[T],
kernel: unary::strided::Kernel,
shape: &[usize],
strides: &[usize],
offset: usize,
) -> Vec<T> {
let device = device();
let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_owned_command_buffer();
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
let kernels = Kernels::new();
call_unary_strided(
&device,
&command_buffer,
&kernels,
kernel,
&input,
shape,
strides,
offset,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
@@ -211,10 +338,77 @@ mod tests {
#[test]
fn cos_f32() {
let v = vec![1.0f32, 2.0, 3.0];
let results = run_cos(&v, "float");
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
let v = vec![1.0f32; 10_000];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn cos_f32_strided() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
// Shape = [6], strides = [1];
let shape = vec![6];
let strides = vec![1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Contiguous
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Transposed
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![1, 3];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Very large
let v = vec![1.0f32; 10_000];
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
@@ -360,7 +554,7 @@ mod tests {
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let results = run_cos(&v, "half");
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
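
For reference, the `unary!` macro in the first hunk expands each op into per-dtype kernel-name constants. A hand-expanded sketch for `cos` alone (not code from the commit):

pub mod contiguous {
    pub struct Kernel(pub &'static str);
    pub mod cos {
        use super::Kernel;
        pub const FLOAT: Kernel = Kernel("cos_float");
        pub const HALF: Kernel = Kernel("cos_half");
        pub const BFLOAT: Kernel = Kernel("cos_bfloat");
    }
}
// The strided module mirrors this with a "_strided" suffix,
// e.g. Kernel("cos_float_strided").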


@@ -1,28 +1,19 @@
#include <metal_stdlib>
using namespace metal;
METAL_FUNC bool is_contiguous(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
size_t acc = 1;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
if (acc != strides[dim_idx]) {
return false;
}
acc *= dims[dim_idx];
}
return true;
}
struct Info{
device size_t &num_dims;
device size_t *dims;
device size_t *strides;
};
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
constant size_t *strides,
constant size_t &offset
) {
uint strided_i = 0;
uint strided_i = offset;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
@@ -40,37 +31,40 @@ kernel void FN_NAME( \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
if (i > dim){ \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = FN(input[i]); \
} \
output[i] = FN(input[i]); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *info, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &offset, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
constant size_t *dims = info; \
constant size_t *strides = info + num_dims; \
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
for (size_t i = start; i < stop; i++) { \
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides, offset)]); \
} \
}
UNARY(cos, float, cos_float, cos_float_strided);
UNARY(cos, half, cos_half, cos_half_strided);
UNARY(sin, float, sin_float, sin_float_strided);
UNARY(sin, half, sin_half, sin_half_strided);
#if __METAL_VERSION__ >= 310
UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided);
UNARY(sin, bfloat, sin_bfloat, sin_bfloat_strided);
#endif
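
To make the indexing concrete, here is an illustrative Rust port of `get_strided_index` (not from the commit; the `idx /= dims[dim_idx]` step falls outside the hunk shown above but is implied by the helper's contract):

fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize], offset: usize) -> usize {
    // Walk dimensions from innermost (last) to outermost, peeling off
    // the coordinate along each dimension and scaling by its stride.
    let mut strided_i = offset;
    for d in (0..dims.len()).rev() {
        strided_i += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    strided_i
}

// Transposed test case above: shape [3, 2], strides [1, 3], offset 0.
// Element 1 maps to storage offset 3, so output[1] = cos(v[3]) = cos(4.0) ≈ -0.6536.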