Writing unary as macro instead, protecting bfloat type with proper metal version.

This commit is contained in:
Nicolas Patry
2023-11-06 15:36:48 +01:00
parent 63cce76b84
commit dedc8c3656
4 changed files with 56 additions and 32 deletions

View File

@ -12,3 +12,6 @@ license.workspace = true
metal = { workspace = true }
once_cell = "1.18.0"
thiserror = { workspace = true }
[dev-dependencies]
half = { workspace = true }

View File

@ -48,9 +48,13 @@ kernel void FN_NAME( \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)

View File

@ -131,11 +131,10 @@ pub fn void_ptr<T>(v: &T) -> *const c_void {
#[cfg(test)]
mod tests {
use super::*;
use half::f16;
use metal::{
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLResourceUsage,
MTLSize, NSUInteger,
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLSize, NSUInteger,
};
use std::ffi::c_void;
use std::mem;
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
@ -143,32 +142,26 @@ mod tests {
v.iter().map(|t| f32::round(t * b) / b).collect()
}
#[test]
fn cos() {
let v = vec![1.0f32, 2.0, 3.0];
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
let device = Device::system_default().unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
let input = device.new_buffer_with_data(
v.as_ptr() as *const c_void,
(v.len() * mem::size_of::<f32>()) as u64,
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
option,
);
let output = device.new_buffer((v.len() * mem::size_of::<f32>()) as u64, option);
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, option);
let library = device
.new_library_with_source(UNARY, &CompileOptions::new())
.expect("Failed to load unary library");
let func = library.get_function("cos", None).unwrap();
let argument_encoder = func.new_argument_encoder(0);
let arg_buffer = device.new_buffer(
argument_encoder.encoded_length(),
MTLResourceOptions::StorageModeShared,
);
argument_encoder.set_argument_buffer(&arg_buffer, 0);
argument_encoder.set_buffer(0, &input, 0);
argument_encoder.set_buffer(1, &output, 0);
let func = library.get_function(&format!("cos_{name}"), None).unwrap();
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@ -178,11 +171,10 @@ mod tests {
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&arg_buffer), 0);
encoder.use_resource(&input, MTLResourceUsage::Read);
encoder.use_resource(&output, MTLResourceUsage::Write);
encoder.set_buffer(0, Some(&input), 0);
encoder.set_buffer(1, Some(&output), 0);
let width = 16;
@ -202,9 +194,14 @@ mod tests {
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn cos_f32() {
let v = vec![1.0f32, 2.0, 3.0];
let results = run_cos(&v, "float");
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
let results = output.read_to_vec::<f32>(v.len());
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
}
@ -343,4 +340,16 @@ mod tests {
let result = outputs_buffer.read_to_vec::<f32>(right.len());
assert_eq!(result, expected);
}
// Runs the half-precision cosine kernel and pins both the GPU output and a
// CPU reference to four decimal places.
#[test]
fn cos_f16() {
    let v: Vec<f16> = [1.0f32, 2.0, 3.0]
        .iter()
        .copied()
        .map(f16::from_f32)
        .collect();
    let results = run_cos(&v, "half");
    let expected: Vec<f16> = v
        .iter()
        .map(|x| f16::from_f32(f32::cos(x.to_f32())))
        .collect();
    // The two vectors are asserted separately: the kernel output and the CPU
    // reference differ slightly in their last digits.
    assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
    assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
}
}

View File

@ -2,13 +2,21 @@
using namespace metal;
struct Input {
device float *input;
device float *output;
};
kernel void cos(device Input& args [[ buffer(0) ]], uint index [[thread_position_in_grid]])
template <typename T>
kernel void unary_cos(device const T *input, device T *output, uint index [[thread_position_in_grid]])
{
args.output[index] = cos(args.input[index]);
output[index] = cos(input[index]);
}
// Generates a compute kernel named FN_NAME that applies the unary function FN
// element-wise: each thread reads input[index] and writes FN(input[index]) to
// output[index], where index is the thread's position in the grid.
// NOTE(review): there is no bounds check on index; the dispatch is presumably
// expected not to exceed the buffer length — confirm against the callers.
#define UNARY(FN, TYPENAME, FN_NAME) \
kernel void FN_NAME(device const TYPENAME *input, device TYPENAME *output, uint index [[thread_position_in_grid]]) \
{ \
    output[index] = FN(input[index]);\
}

UNARY(cos, float, cos_float);
UNARY(cos, half, cos_half);

// The bfloat instantiation is compiled only for Metal 3.1 and later, matching
// the guard used for the other bfloat kernels. It must use its own kernel
// name: repeating cos_half here would redefine the kernel declared above.
#if __METAL_VERSION__ >= 310
UNARY(cos, bfloat, cos_bfloat);
#endif