mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
index_add works
This commit is contained in:
@ -19,7 +19,7 @@ struct fault_counter {
|
||||
};
|
||||
|
||||
constant uint IDS_DIM_SIZE [[function_constant(0)]];
|
||||
constant uint SRC_DIM_SIZE [[function_constant(1)]];
|
||||
constant uint SRC_DIM_SIZE [[function_constant(1)]]; // Not needed
|
||||
constant uint DST_DIM_SIZE [[function_constant(2)]];
|
||||
constant uint LEFT_SIZE [[function_constant(3)]];
|
||||
constant uint RIGHT_SIZE [[function_constant(4)]];
|
||||
@ -29,21 +29,16 @@ kernel void index_add(
|
||||
device uint *ids [[buffer(0)]],
|
||||
device float *inp [[buffer(1)]],
|
||||
device float *out [[buffer(2)]],
|
||||
|
||||
uint grid_size [[threadgroups_per_grid]], // gridDim
|
||||
uint gid [[thread_position_in_grid]], // blockIdx
|
||||
uint num_threads [[threads_per_grid]], // blockDim
|
||||
uint thread_index [[thread_index_in_threadgroup]] // threadIdx
|
||||
uint thread_index [[thread_index_in_threadgroup]]
|
||||
) {
|
||||
for (uint i = gid * num_threads + thread_index; i < NUMEL; i += num_threads * grid_size) {
|
||||
const uint pre = i / RIGHT_SIZE;
|
||||
const uint post = i % RIGHT_SIZE;
|
||||
const uint i = thread_index;
|
||||
const uint pre = i / RIGHT_SIZE;
|
||||
const uint post = i % RIGHT_SIZE;
|
||||
|
||||
for (uint j = 0; j < IDS_DIM_SIZE; j++) {
|
||||
const uint idx = ids[j];
|
||||
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
|
||||
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
}
|
||||
for (uint j = 0; j < IDS_DIM_SIZE; ++j) {
|
||||
const uint idx = ids[j];
|
||||
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
|
||||
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
use metal::{Buffer, CompileOptions, Device, Function, Library, NSUInteger};
|
||||
use metal::{Buffer, CompileOptions, Device, Function, Library};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
@ -56,7 +56,7 @@ impl Kernels {
|
||||
}
|
||||
}
|
||||
|
||||
fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) {
|
||||
fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) {
|
||||
todo!("Call unary");
|
||||
}
|
||||
|
||||
@ -154,12 +154,14 @@ mod tests {
|
||||
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
||||
let right = [1.0f32; 15];
|
||||
let index = [0u32, 4, 2];
|
||||
|
||||
let ids_dim_size = index.len() as u32;
|
||||
let src_dim_size = 2u32;
|
||||
let dst_dim_size = 2u32;
|
||||
let left_size = left.len() as u32;
|
||||
let right_size = right.len() as u32;
|
||||
let numel = left_size * right_size;
|
||||
|
||||
// Are these reversed?
|
||||
let src_dim_size: u32 = 9;
|
||||
let dst_dim_size: u32 = 15;
|
||||
let left_size: u32 = 3;
|
||||
let right_size: u32 = 3;
|
||||
|
||||
let fcv = FunctionConstantValues::new();
|
||||
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
|
||||
@ -200,18 +202,16 @@ mod tests {
|
||||
let width = 16;
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width,
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: (numel as NSUInteger + width) / width,
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
println!("{:?}", thread_group_count);
|
||||
println!("{:?}", thread_group_size);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -222,8 +222,6 @@ mod tests {
|
||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
||||
];
|
||||
let result = outputs.read_to_vec::<f32>(right.len());
|
||||
println!("{:?}", result);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user