index_add works

This commit is contained in:
Ivar Flakstad
2023-11-03 21:12:52 +01:00
parent 0794e70a19
commit e708d35e7f
2 changed files with 21 additions and 28 deletions

View File

@ -19,7 +19,7 @@ struct fault_counter {
};
constant uint IDS_DIM_SIZE [[function_constant(0)]];
constant uint SRC_DIM_SIZE [[function_constant(1)]];
constant uint SRC_DIM_SIZE [[function_constant(1)]]; // Not needed
constant uint DST_DIM_SIZE [[function_constant(2)]];
constant uint LEFT_SIZE [[function_constant(3)]];
constant uint RIGHT_SIZE [[function_constant(4)]];
@ -29,21 +29,16 @@ kernel void index_add(
device uint *ids [[buffer(0)]],
device float *inp [[buffer(1)]],
device float *out [[buffer(2)]],
uint grid_size [[threadgroups_per_grid]], // gridDim
uint gid [[thread_position_in_grid]], // blockIdx
uint num_threads [[threads_per_grid]], // blockDim
uint thread_index [[thread_index_in_threadgroup]] // threadIdx
uint thread_index [[thread_index_in_threadgroup]]
) {
for (uint i = gid * num_threads + thread_index; i < NUMEL; i += num_threads * grid_size) {
const uint pre = i / RIGHT_SIZE;
const uint post = i % RIGHT_SIZE;
const uint i = thread_index;
const uint pre = i / RIGHT_SIZE;
const uint post = i % RIGHT_SIZE;
for (uint j = 0; j < IDS_DIM_SIZE; j++) {
const uint idx = ids[j];
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
out[dst_i] += inp[src_i];
}
for (uint j = 0; j < IDS_DIM_SIZE; ++j) {
const uint idx = ids[j];
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
out[dst_i] += inp[src_i];
}
}

View File

@ -1,4 +1,4 @@
use metal::{Buffer, CompileOptions, Device, Function, Library, NSUInteger};
use metal::{Buffer, CompileOptions, Device, Function, Library};
use std::collections::HashMap;
use std::sync::RwLock;
@ -56,7 +56,7 @@ impl Kernels {
}
}
fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) {
fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) {
todo!("Call unary");
}
@ -154,12 +154,14 @@ mod tests {
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let src_dim_size = 2u32;
let dst_dim_size = 2u32;
let left_size = left.len() as u32;
let right_size = right.len() as u32;
let numel = left_size * right_size;
// Are these reversed?
let src_dim_size: u32 = 9;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let fcv = FunctionConstantValues::new();
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
@ -200,18 +202,16 @@ mod tests {
let width = 16;
let thread_group_count = MTLSize {
width,
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: (numel as NSUInteger + width) / width,
width,
height: 1,
depth: 1,
};
println!("{:?}", thread_group_count);
println!("{:?}", thread_group_size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -222,8 +222,6 @@ mod tests {
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs.read_to_vec::<f32>(right.len());
println!("{:?}", result);
assert_eq!(result, expected);
}
}