Scatter add.

This commit is contained in:
Nicolas Patry
2023-12-18 10:32:22 +01:00
parent 586b6f6fff
commit 6a3ca7da0c
3 changed files with 147 additions and 17 deletions

View File

@ -1020,7 +1020,9 @@ pub fn call_gather(
ids_size: usize,
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@ -1043,8 +1045,60 @@ pub fn call_gather(
src_dim_size,
right_size,
ids_size,
input,
ids,
(input, input_offset),
(ids, ids_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
pub fn call_scatter_add(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
src_shape: &[usize],
dst_shape: &[usize],
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
let right_size: usize = src_shape[dim + 1..].iter().product();
let src_dim_size = src_shape[dim];
let dst_el = left_size * right_size;
let dst_dim_size = dst_shape[dim];
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
dst_dim_size,
(input, input_offset),
(ids, ids_offset),
output
)
);