Working but failing tests because of threadgroup.

This commit is contained in:
Nicolas Patry
2023-11-06 17:04:47 +01:00
parent dedc8c3656
commit 677495f9b8
3 changed files with 92 additions and 25 deletions

View File

@ -59,4 +59,4 @@ kernel void affine(
out[strided_i] = x * mul + add; out[strided_i] = x * mul + add;
} }
} }
} }

View File

@ -137,6 +137,10 @@ mod tests {
}; };
use std::mem; use std::mem;
fn device() -> Device {
Device::system_default().unwrap()
}
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> { fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits); let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t * b) / b).collect() v.iter().map(|t| f32::round(t * b) / b).collect()
@ -148,8 +152,9 @@ mod tests {
} }
fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> { fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
let device = device();
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache; let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
let device = Device::system_default().unwrap(); let option = metal::MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = device.new_buffer_with_data( let input = device.new_buffer_with_data(
@ -165,18 +170,27 @@ mod tests {
let pipeline_state_descriptor = ComputePipelineDescriptor::new(); let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func)); pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline_state = device let pipeline = device
.new_compute_pipeline_state_with_function( .new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(), pipeline_state_descriptor.compute_function().unwrap(),
) )
.unwrap(); .unwrap();
let encoder = command_buffer.new_compute_command_encoder(); let dim: u32 = v.len() as u32;
encoder.set_compute_pipeline_state(&pipeline_state); // let num_dims: u32 = 1;
encoder.set_buffer(0, Some(&input), 0); // let info = [v.len() as u32, 1];
encoder.set_buffer(1, Some(&output), 0);
let width = 16; let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, 4, void_ptr(&dim));
// encoder.set_bytes(1, 4, void_ptr(&num_dims));
// encoder.set_bytes(2, 4, void_ptr(&info));
encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0);
let width = v.len() as NSUInteger;
let thread_group_count = MTLSize { let thread_group_count = MTLSize {
width, width,
@ -185,7 +199,7 @@ mod tests {
}; };
let thread_group_size = MTLSize { let thread_group_size = MTLSize {
width: (v.len() as u64 + width) / width, width: pipeline.max_total_threads_per_threadgroup(),
height: 1, height: 1,
depth: 1, depth: 1,
}; };
@ -208,7 +222,7 @@ mod tests {
#[test] #[test]
fn affine() { fn affine() {
let device = Device::system_default().expect("no device found"); let device = device();
let options = CompileOptions::new(); let options = CompileOptions::new();
let library = device.new_library_with_source(AFFINE, &options).unwrap(); let library = device.new_library_with_source(AFFINE, &options).unwrap();
@ -225,7 +239,8 @@ mod tests {
let pipeline = device let pipeline = device
.new_compute_pipeline_state_with_function(&function) .new_compute_pipeline_state_with_function(&function)
.unwrap(); .unwrap();
let options = MTLResourceOptions::StorageModeShared; // let options = MTLResourceOptions::StorageModeShared;
let options = metal::MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
@ -291,7 +306,7 @@ mod tests {
let pipeline = device let pipeline = device
.new_compute_pipeline_state_with_function(&function) .new_compute_pipeline_state_with_function(&function)
.unwrap(); .unwrap();
let options = MTLResourceOptions::StorageModeShared; let options = metal::MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();

View File

@ -1,22 +1,74 @@
#include <metal_stdlib> #include <metal_stdlib>
#
METAL_FUNC bool is_contiguous(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
size_t acc = 1;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
if (acc != strides[dim_idx]) {
return false;
}
acc *= dims[dim_idx];
}
return true;
}
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal; using namespace metal;
template <typename T> #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void unary_cos(device const T *input, device T *output, uint index [[thread_position_in_grid]]) kernel void FN_NAME( \
{ constant size_t &dim, \
output[index] = cos(input[index]); device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
output[i] = FN(input[i]); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *info, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
constant size_t *dims = info; \
constant size_t *strides = info + num_dims; \
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
for (size_t i = start; i < stop; i++) { \
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
output[i] = 1; \
} \
} }
#define UNARY(FN, TYPENAME, FN_NAME) \ UNARY(cos, float, cos_float, cos_float_strided);
kernel void FN_NAME(device const TYPENAME *input, device TYPENAME *output, uint index [[thread_position_in_grid]]) \ UNARY(cos, half, cos_half, cos_half_strided);
{ \
output[index] = FN(input[index]);\
}
UNARY(cos, float, cos_float);
UNARY(cos, half, cos_half);
#if __METAL_VERSION__ >= 310 #if __METAL_VERSION__ >= 310
UNARY(cos, half, cos_half); UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided);
#endif #endif