Adding indexing.

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-11-10 01:58:51 +01:00
parent 9a2784b8ab
commit 02c2ec2c71
4 changed files with 191 additions and 138 deletions

View File

@ -690,6 +690,63 @@ pub fn call_where_cond_strided(
Ok(())
}
/// Encodes a dispatch of the indexing kernel `name` onto `command_buffer`.
///
/// `shape` is split around `dim` into three factors that are forwarded to the
/// kernel: the product of the dimensions before `dim` (`left_size`), the size
/// of `dim` itself (`src_dim_size`) and the product of the dimensions after
/// `dim` (`right_size`). The number of output elements is
/// `ids_size * left_size * right_size`.
///
/// The kernel receives its arguments in this order: `dst_el`, `left_size`,
/// `src_dim_size`, `right_size`, `ids_size`, `input`, `ids`, `output`.
///
/// The function only encodes the work; committing the command buffer and
/// waiting for completion is left to the caller.
///
/// # Errors
///
/// Returns the error produced by `Kernels::load_function` when the kernel
/// `name` cannot be loaded from the indexing source.
///
/// # Panics
///
/// Panics if `dim >= shape.len()`, or if the compute pipeline state cannot be
/// created for the loaded function.
pub fn call_index_select(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: &Buffer,
    ids: &Buffer,
    output: &mut Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let src_dim_size = shape[dim];
    let dst_el = ids_size * left_size * right_size;

    // Nothing to compute: with zero output elements the threadgroup width
    // below would be 0, and the grid-size computation would then underflow /
    // divide by zero. Bail out before encoding any work.
    if dst_el == 0 {
        return Ok(());
    }

    let func = kernels.load_function(device, Source::Indexing, name)?;
    let pipeline = device
        .new_compute_pipeline_state_with_function(&func)
        .unwrap();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            ids_size,
            input,
            ids,
            output
        )
    );

    // One thread per output element: cap the threadgroup width at what the
    // pipeline supports and launch enough groups to cover `dst_el`
    // (ceiling division).
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
    let grid_size = MTLSize {
        width: (dst_el as u64 + width - 1) / width,
        height: 1,
        depth: 1,
    };
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(grid_size, thread_group_size);
    encoder.end_encoding();
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
@ -1003,61 +1060,32 @@ mod tests {
dim: usize,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids.len() * left_size * right_size;
let ids_size = ids.len();
let function = library.get_function("is_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
ids_size,
&embeddings_buffer,
&ids_buffer,
&mut dst_buffer
)
);
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
&kernels,
"is_u32_f32",
shape,
ids.len(),
dim,
&embeddings_buffer,
&ids_buffer,
&mut dst_buffer,
)
.unwrap();
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
let grid_size = MTLSize {
width: (dst_el as u64 + width - 1) / width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
println!("{width:?} - {:?}", grid_size);
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();