mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Adding indexing.
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:

committed by
Nicolas Patry

parent
df6814f34e
commit
f82bf2d915
@ -690,6 +690,63 @@ pub fn call_where_cond_strided(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_index_select(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
ids_size: usize,
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
ids: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
let right_size: usize = shape[dim + 1..].iter().product();
|
||||
let src_dim_size = shape[dim];
|
||||
let dst_el = ids_size * left_size * right_size;
|
||||
|
||||
let func = kernels.load_function(device, Source::Indexing, name)?;
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&func)
|
||||
.unwrap();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
left_size,
|
||||
src_dim_size,
|
||||
right_size,
|
||||
ids_size,
|
||||
input,
|
||||
ids,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
|
||||
let grid_size = MTLSize {
|
||||
width: (dst_el as u64 + width - 1) / width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.dispatch_thread_groups(grid_size, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -1003,61 +1060,32 @@ mod tests {
|
||||
dim: usize,
|
||||
) -> Vec<T> {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
let options = CompileOptions::new();
|
||||
let library = device.new_library_with_source(INDEXING, &options).unwrap();
|
||||
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
let right_size: usize = shape[dim + 1..].iter().product();
|
||||
let src_dim_size = shape[dim];
|
||||
let dst_el = ids.len() * left_size * right_size;
|
||||
let ids_size = ids.len();
|
||||
|
||||
let function = library.get_function("is_u32_f32", None).unwrap();
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let embeddings_buffer = new_buffer(&device, &embeddings);
|
||||
let ids_buffer = new_buffer(&device, &ids);
|
||||
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
let right_size: usize = shape[dim + 1..].iter().product();
|
||||
let dst_el = ids.len() * left_size * right_size;
|
||||
let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
dst_el,
|
||||
left_size,
|
||||
src_dim_size,
|
||||
right_size,
|
||||
ids_size,
|
||||
&embeddings_buffer,
|
||||
&ids_buffer,
|
||||
&mut dst_buffer
|
||||
)
|
||||
);
|
||||
let kernels = Kernels::new();
|
||||
call_index_select(
|
||||
&device,
|
||||
&command_buffer,
|
||||
&kernels,
|
||||
"is_u32_f32",
|
||||
shape,
|
||||
ids.len(),
|
||||
dim,
|
||||
&embeddings_buffer,
|
||||
&ids_buffer,
|
||||
&mut dst_buffer,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
|
||||
let grid_size = MTLSize {
|
||||
width: (dst_el as u64 + width - 1) / width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
println!("{width:?} - {:?}", grid_size);
|
||||
|
||||
encoder.dispatch_thread_groups(grid_size, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
|
Reference in New Issue
Block a user