mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
index_add works
This commit is contained in:
@ -19,7 +19,7 @@ struct fault_counter {
|
|||||||
};
|
};
|
||||||
|
|
||||||
constant uint IDS_DIM_SIZE [[function_constant(0)]];
|
constant uint IDS_DIM_SIZE [[function_constant(0)]];
|
||||||
constant uint SRC_DIM_SIZE [[function_constant(1)]];
|
constant uint SRC_DIM_SIZE [[function_constant(1)]]; // Not needed
|
||||||
constant uint DST_DIM_SIZE [[function_constant(2)]];
|
constant uint DST_DIM_SIZE [[function_constant(2)]];
|
||||||
constant uint LEFT_SIZE [[function_constant(3)]];
|
constant uint LEFT_SIZE [[function_constant(3)]];
|
||||||
constant uint RIGHT_SIZE [[function_constant(4)]];
|
constant uint RIGHT_SIZE [[function_constant(4)]];
|
||||||
@ -29,21 +29,16 @@ kernel void index_add(
|
|||||||
device uint *ids [[buffer(0)]],
|
device uint *ids [[buffer(0)]],
|
||||||
device float *inp [[buffer(1)]],
|
device float *inp [[buffer(1)]],
|
||||||
device float *out [[buffer(2)]],
|
device float *out [[buffer(2)]],
|
||||||
|
uint thread_index [[thread_index_in_threadgroup]]
|
||||||
uint grid_size [[threadgroups_per_grid]], // gridDim
|
|
||||||
uint gid [[thread_position_in_grid]], // blockIdx
|
|
||||||
uint num_threads [[threads_per_grid]], // blockDim
|
|
||||||
uint thread_index [[thread_index_in_threadgroup]] // threadIdx
|
|
||||||
) {
|
) {
|
||||||
for (uint i = gid * num_threads + thread_index; i < NUMEL; i += num_threads * grid_size) {
|
const uint i = thread_index;
|
||||||
const uint pre = i / RIGHT_SIZE;
|
const uint pre = i / RIGHT_SIZE;
|
||||||
const uint post = i % RIGHT_SIZE;
|
const uint post = i % RIGHT_SIZE;
|
||||||
|
|
||||||
for (uint j = 0; j < IDS_DIM_SIZE; j++) {
|
for (uint j = 0; j < IDS_DIM_SIZE; ++j) {
|
||||||
const uint idx = ids[j];
|
const uint idx = ids[j];
|
||||||
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
|
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
|
||||||
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
|
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
|
||||||
out[dst_i] += inp[src_i];
|
out[dst_i] += inp[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
@ -1,4 +1,4 @@
|
|||||||
use metal::{Buffer, CompileOptions, Device, Function, Library, NSUInteger};
|
use metal::{Buffer, CompileOptions, Device, Function, Library};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) {
|
fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) {
|
||||||
todo!("Call unary");
|
todo!("Call unary");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -154,12 +154,14 @@ mod tests {
|
|||||||
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
||||||
let right = [1.0f32; 15];
|
let right = [1.0f32; 15];
|
||||||
let index = [0u32, 4, 2];
|
let index = [0u32, 4, 2];
|
||||||
|
|
||||||
let ids_dim_size = index.len() as u32;
|
let ids_dim_size = index.len() as u32;
|
||||||
let src_dim_size = 2u32;
|
|
||||||
let dst_dim_size = 2u32;
|
// Are these reversed?
|
||||||
let left_size = left.len() as u32;
|
let src_dim_size: u32 = 9;
|
||||||
let right_size = right.len() as u32;
|
let dst_dim_size: u32 = 15;
|
||||||
let numel = left_size * right_size;
|
let left_size: u32 = 3;
|
||||||
|
let right_size: u32 = 3;
|
||||||
|
|
||||||
let fcv = FunctionConstantValues::new();
|
let fcv = FunctionConstantValues::new();
|
||||||
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
|
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
|
||||||
@ -200,18 +202,16 @@ mod tests {
|
|||||||
let width = 16;
|
let width = 16;
|
||||||
|
|
||||||
let thread_group_count = MTLSize {
|
let thread_group_count = MTLSize {
|
||||||
width,
|
width: 1,
|
||||||
height: 1,
|
height: 1,
|
||||||
depth: 1,
|
depth: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let thread_group_size = MTLSize {
|
let thread_group_size = MTLSize {
|
||||||
width: (numel as NSUInteger + width) / width,
|
width,
|
||||||
height: 1,
|
height: 1,
|
||||||
depth: 1,
|
depth: 1,
|
||||||
};
|
};
|
||||||
println!("{:?}", thread_group_count);
|
|
||||||
println!("{:?}", thread_group_size);
|
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -222,8 +222,6 @@ mod tests {
|
|||||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
||||||
];
|
];
|
||||||
let result = outputs.read_to_vec::<f32>(right.len());
|
let result = outputs.read_to_vec::<f32>(right.len());
|
||||||
println!("{:?}", result);
|
|
||||||
|
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user