Debugging index_add.

Ivar Flakstad
2023-11-03 12:08:58 +01:00
parent f57e3164ae
commit 0794e70a19
5 changed files with 183 additions and 28 deletions

@@ -1,8 +1,9 @@
-use metal::{Buffer, Device, Function, Library, CompileOptions};
use metal::{Buffer, CompileOptions, Device, Function, Library, NSUInteger};
use std::collections::HashMap;
use std::sync::RwLock;
-static UNARY: &'static str = include_str!("unary.metal");
pub const INDEXING: &str = include_str!("indexing.metal");
pub const UNARY: &str = include_str!("unary.metal");
pub enum Error {}
@@ -63,10 +64,16 @@ fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) {
mod tests {
use super::*;
use metal::{
-ComputePipelineDescriptor, MTLResourceOptions, MTLResourceUsage, MTLSize,
CompileOptions, ComputePipelineDescriptor, Device, FunctionConstantValues, MTLDataType,
MTLResourceOptions, MTLResourceUsage, MTLSize, NSUInteger,
};
use std::ffi::c_void;
use std::mem;
-fn approx(v: Vec<f32>, digits: i32) -> Vec<f32>{
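// Reinterpret a typed reference as an untyped pointer for the Metal FFI calls.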
pub fn void_ptr<T>(v: &T) -> *const c_void {
(v as *const T).cast()
}
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t * b) / b).collect()
}
@@ -80,11 +87,11 @@ mod tests {
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
let input = device.new_buffer_with_data(
-v.as_ptr() as *const core::ffi::c_void,
-(v.len() * core::mem::size_of::<f32>()) as u64,
v.as_ptr() as *const c_void,
(v.len() * mem::size_of::<f32>()) as u64,
option,
);
-let output = device.new_buffer((v.len() * core::mem::size_of::<f32>()) as u64, option);
let output = device.new_buffer((v.len() * mem::size_of::<f32>()) as u64, option);
let library = device
.new_library_with_source(UNARY, &CompileOptions::new())
.expect("Failed to load unary library");
@@ -130,9 +137,93 @@ mod tests {
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
let results = output.read_to_vec::<f32>(v.len());
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
}
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
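// 9 source values (three rows of 3), a destination of 15 ones, and the
// destination row (width 3) that each source row is accumulated into.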
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let src_dim_size = 2u32;
let dst_dim_size = 2u32;
let left_size = left.len() as u32;
let right_size = right.len() as u32;
let numel = left_size * right_size;
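// Shape information is handed to the kernel as Metal function constants
// (uint values at indices 0 through 4).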
let fcv = FunctionConstantValues::new();
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
fcv.set_constant_value_at_index(void_ptr(&src_dim_size), MTLDataType::UInt, 1);
fcv.set_constant_value_at_index(void_ptr(&dst_dim_size), MTLDataType::UInt, 2);
fcv.set_constant_value_at_index(void_ptr(&left_size), MTLDataType::UInt, 3);
fcv.set_constant_value_at_index(void_ptr(&right_size), MTLDataType::UInt, 4);
let function = library.get_function("index_add", Some(fcv)).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let options = MTLResourceOptions::StorageModeShared;
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
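// Shared-storage buffers initialized from the host-side index, input and output arrays.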
let ids = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
let inputs = device.new_buffer_with_data(void_ptr(&left), input_size, options);
let outputs = device.new_buffer_with_data(void_ptr(&right), output_size, options);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let thread_group_memory_length = output_size;
encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
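// Declare how each buffer is used, then bind ids/inputs/outputs at buffer indices 0, 1 and 2.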
encoder.use_resource(&ids, MTLResourceUsage::Read);
encoder.use_resource(&inputs, MTLResourceUsage::Read);
encoder.use_resource(&outputs, MTLResourceUsage::Write);
encoder.set_buffer(0, Some(&ids), 0);
encoder.set_buffer(1, Some(&inputs), 0);
encoder.set_buffer(2, Some(&outputs), 0);
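// dispatch_thread_groups takes the threadgroup count first and the threads per
// threadgroup second; together they have to cover all `numel` elements.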
let width = 16;
let thread_group_count = MTLSize {
width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: (numel as NSUInteger + width) / width,
height: 1,
depth: 1,
};
println!("{:?}", thread_group_count);
println!("{:?}", thread_group_size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
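// Each destination row selected by `index` should now hold one plus the corresponding source row.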
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs.read_to_vec::<f32>(right.len());
println!("{:?}", result);
assert_eq!(result, expected);
}
}