mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Scatter add.
This commit is contained in:
@ -1020,7 +1020,9 @@ pub fn call_gather(
|
||||
ids_size: usize,
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
@ -1043,8 +1045,60 @@ pub fn call_gather(
|
||||
src_dim_size,
|
||||
right_size,
|
||||
ids_size,
|
||||
input,
|
||||
ids,
|
||||
(input, input_offset),
|
||||
(ids, ids_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_scatter_add(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
src_shape: &[usize],
|
||||
dst_shape: &[usize],
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = src_shape[..dim].iter().product();
|
||||
let right_size: usize = src_shape[dim + 1..].iter().product();
|
||||
let src_dim_size = src_shape[dim];
|
||||
let dst_el = left_size * right_size;
|
||||
let dst_dim_size = dst_shape[dim];
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
left_size,
|
||||
src_dim_size,
|
||||
right_size,
|
||||
dst_dim_size,
|
||||
(input, input_offset),
|
||||
(ids, ids_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
|
Reference in New Issue
Block a user