Impl index_add via template for all types

This commit is contained in:
Ivar Flakstad
2023-11-04 08:46:08 +01:00
parent e708d35e7f
commit d4d6850c78
3 changed files with 103 additions and 70 deletions

View File

@ -1,44 +1,71 @@
#include <metal_stdlib>
#include <metal_config>
#define METAL_FUNC inline __attribute__((__always_inline__))
using namespace metal;
struct fault_counter {
uint counter;
uint tolerance;
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
fault_counter(uint tolerance) {
this->counter = 0;
this->tolerance = tolerance;
}
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
bool quit() {
counter += 1;
return (counter > tolerance);
}
};
constant uint IDS_DIM_SIZE [[function_constant(0)]];
constant uint SRC_DIM_SIZE [[function_constant(1)]]; // Not needed
constant uint DST_DIM_SIZE [[function_constant(2)]];
constant uint LEFT_SIZE [[function_constant(3)]];
constant uint RIGHT_SIZE [[function_constant(4)]];
constant uint NUMEL = LEFT_SIZE * RIGHT_SIZE;
kernel void index_add(
device uint *ids [[buffer(0)]],
device float *inp [[buffer(1)]],
device float *out [[buffer(2)]],
uint threadgroup_size [[threads_per_threadgroup]],
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index [[thread_index_in_threadgroup]]
) {
const uint i = thread_index;
const uint pre = i / RIGHT_SIZE;
const uint post = i % RIGHT_SIZE;
for (uint j = 0; j < IDS_DIM_SIZE; ++j) {
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
if (gid >= left_size * right_size) {
return;
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)

View File

@ -64,8 +64,8 @@ fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usiz
mod tests {
use super::*;
use metal::{
CompileOptions, ComputePipelineDescriptor, Device, FunctionConstantValues, MTLDataType,
MTLResourceOptions, MTLResourceUsage, MTLSize, NSUInteger,
CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLResourceUsage,
MTLSize, NSUInteger,
};
use std::ffi::c_void;
use std::mem;
@ -99,7 +99,7 @@ mod tests {
let argument_encoder = func.new_argument_encoder(0);
let arg_buffer = device.new_buffer(
argument_encoder.encoded_length(),
MTLResourceOptions::empty(),
MTLResourceOptions::StorageModeShared,
);
argument_encoder.set_argument_buffer(&arg_buffer, 0);
argument_encoder.set_buffer(0, &input, 0);
@ -154,66 +154,54 @@ mod tests {
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
// Are these reversed?
let src_dim_size: u32 = 9;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let fcv = FunctionConstantValues::new();
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
fcv.set_constant_value_at_index(void_ptr(&src_dim_size), MTLDataType::UInt, 1);
fcv.set_constant_value_at_index(void_ptr(&dst_dim_size), MTLDataType::UInt, 2);
fcv.set_constant_value_at_index(void_ptr(&left_size), MTLDataType::UInt, 3);
fcv.set_constant_value_at_index(void_ptr(&right_size), MTLDataType::UInt, 4);
let function = library.get_function("index_add", Some(fcv)).unwrap();
let function = library.get_function("ia_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let options = MTLResourceOptions::StorageModeShared;
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
let ids = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
let inputs = device.new_buffer_with_data(void_ptr(&left), input_size, options);
let outputs = device.new_buffer_with_data(void_ptr(&right), output_size, options);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
encoder.set_compute_pipeline_state(&pipeline);
let thread_group_memory_length = output_size;
encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
encoder.use_resource(&ids, MTLResourceUsage::Read);
encoder.use_resource(&inputs, MTLResourceUsage::Read);
encoder.use_resource(&outputs, MTLResourceUsage::Write);
let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options);
let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options);
encoder.set_buffer(0, Some(&ids), 0);
encoder.set_buffer(1, Some(&inputs), 0);
encoder.set_buffer(2, Some(&outputs), 0);
let width = 16;
encoder.set_buffer(0, Some(&index_buffer), 0);
encoder.set_buffer(1, Some(&inputs_buffer), 0);
encoder.set_buffer(2, Some(&outputs_buffer), 0);
let thread_group_count = MTLSize {
width: 1,
encoder.set_bytes(3, 4, void_ptr(&ids_dim_size));
encoder.set_bytes(4, 4, void_ptr(&left_size));
encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
encoder.set_bytes(6, 4, void_ptr(&right_size));
let grid_size = MTLSize {
width: right.len() as NSUInteger,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.dispatch_threads(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
@ -221,7 +209,7 @@ mod tests {
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs.read_to_vec::<f32>(right.len());
let result = outputs_buffer.read_to_vec::<f32>(right.len());
assert_eq!(result, expected);
}
}

View File

@ -0,0 +1,18 @@
#include <metal_stdlib>
using namespace metal;
struct fault_counter {
uint counter;
uint tolerance;
fault_counter(uint tolerance) {
this->counter = 0;
this->tolerance = tolerance;
}
bool quit() {
counter += 1;
return (counter > tolerance);
}
};