Adding indexing.

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Nicolas Patry
2023-11-10 01:58:51 +01:00
parent 9a2784b8ab
commit 02c2ec2c71
4 changed files with 191 additions and 138 deletions


@@ -479,28 +479,40 @@ impl BackendStorage for MetalStorage {
todo!()
}
fn index_select(
&self,
_ids: &Self,
_src_l: &Layout,
_ids_l: &Layout,
_dim: usize,
) -> Result<Self> {
todo!("Index select");
// let ids_shape = ids_l.shape();
// let left_size: usize = src_l.dims()[..dim].iter().product();
// let right_size: usize = src_l.dims()[dim + 1..].iter().product();
// let src_dim_size = src_l.dims()[dim];
// let ids_dim_size = ids_shape.elem_count();
// let dst_el = ids_shape.elem_count() * left_size * right_size;
// let dtype = self.dtype;
// let device = self.device();
// let buffer = device.new_buffer(dst_el, dtype);
// Ok(Self {
// buffer,
// device: device.clone(),
// dtype,
// })
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
assert!(src_l.is_contiguous());
assert!(ids_l.is_contiguous());
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let ids_el = ids_l.shape().elem_count();
let dst_el = ids_el * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
let mut buffer = device.new_buffer(dst_el, dtype);
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
(left, right) => todo!("index select metal {left:?} {right:?}"),
};
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
ids_el,
dim,
&self.buffer,
&ids.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn index_add(

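For reference, a minimal standalone sketch (not from this commit; the helper name is hypothetical) of the size bookkeeping index_select performs: the dims before dim collapse into left_size, the dims after into right_size, and the selected dim is replaced by the number of ids.

fn index_select_sizes(dims: &[usize], dim: usize, ids_el: usize) -> (usize, usize, usize) {
    // Product of the dimensions before and after `dim`.
    let left_size: usize = dims[..dim].iter().product();
    let right_size: usize = dims[dim + 1..].iter().product();
    // The selected dimension is replaced by the number of ids.
    let dst_el = ids_el * left_size * right_size;
    (left_size, right_size, dst_el)
}

fn main() {
    // Selecting 2 indices along dim 1 of a [2, 3, 4] tensor yields [2, 2, 4], i.e. 16 elements.
    assert_eq!(index_select_sizes(&[2, 3, 4], 1, 2), (2, 4, 16));
}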

@@ -1,39 +1,36 @@
#include <metal_stdlib>
using namespace metal;
kernel void is_u32_f32(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device float *input,
const device uint *input_ids,
device float *output,
uint gid [[ thread_position_in_grid ]]
) {
if (gid >= dst_size) {
return;
}
const size_t id_i = gid / right_size / left_size;
const size_t right_rank_i = gid % right_size;
const size_t left_rank_i = gid % left_size;
// Clamp to prevent out-of-bounds indexing,
// since there doesn't seem to be a good way to force a crash.
// No need to check for zero: we only allow unsigned indices.
const uint input_i = min(input_ids[id_i], (uint)(src_dim_size - 1));
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i;
output[gid] = input[src_i];
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
) { \
if (gid >= dst_size) { \
return; \
} \
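/* \
// Decompose the flat destination index into (left, id, right) coordinates; \
// the destination is laid out as [left_size, ids_size, right_size]. \
*/ \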
const size_t right_rank_i = gid % right_size; \
const size_t id_i = (gid / right_size) % ids_size; \
const size_t left_rank_i = gid / (right_size * ids_size); \
/* \
// Clamp to prevent out-of-bounds indexing, \
// since there doesn't seem to be a good way to force a crash. \
// No need to check for zero: we only allow unsigned indices. \
*/ \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; \
output[gid] = input[src_i]; \
}
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
@@ -82,6 +79,7 @@ kernel void FN_NAME( \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
INDEX_OP(is_u32_f32, uint, float)
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)

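As a sanity check, a hedged CPU reference (not part of the commit) that mirrors the per-element arithmetic of the INDEX_OP kernel, assuming a contiguous row-major source and destination:

fn index_select_ref(src: &[f32], ids: &[u32], dims: &[usize], dim: usize) -> Vec<f32> {
    let left_size: usize = dims[..dim].iter().product();
    let right_size: usize = dims[dim + 1..].iter().product();
    let src_dim_size = dims[dim];
    let ids_size = ids.len();
    let mut out = vec![0.0f32; left_size * ids_size * right_size];
    for (gid, dst) in out.iter_mut().enumerate() {
        // Decompose the flat destination index into (left, id, right) coordinates.
        let right_rank_i = gid % right_size;
        let id_i = (gid / right_size) % ids_size;
        let left_rank_i = gid / (right_size * ids_size);
        // Clamp like the kernel does to avoid out-of-bounds reads.
        let input_i = (ids[id_i] as usize).min(src_dim_size - 1);
        let src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
        *dst = src[src_i];
    }
    out
}

fn main() {
    // Select rows 2 and 0 along dim 0 of a [3, 2] matrix.
    let src = [0., 1., 10., 11., 20., 21.];
    assert_eq!(index_select_ref(&src, &[2, 0], &[3, 2], 0), vec![20., 21., 0., 1.]);
}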

@@ -690,6 +690,63 @@ pub fn call_where_cond_strided(
Ok(())
}
pub fn call_index_select(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
ids_size: usize,
dim: usize,
input: &Buffer,
ids: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids_size * left_size * right_size;
let func = kernels.load_function(device, Source::Indexing, name)?;
let pipeline = device
.new_compute_pipeline_state_with_function(&func)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
ids_size,
input,
ids,
output
)
);
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
let grid_size = MTLSize {
width: (dst_el as u64 + width - 1) / width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
Ok(())
}
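
The dispatch sizing above launches one thread per destination element, capped at the pipeline's threadgroup limit; the grid width is the ceiling division of dst_el by that cap. A small sketch of the same arithmetic (assumed helper, not in the commit):

fn launch_dims(dst_el: u64, max_threads: u64) -> (u64, u64) {
    let width = dst_el.min(max_threads); // threads per threadgroup
    let groups = (dst_el + width - 1) / width; // ceil(dst_el / width) threadgroups
    (groups, width)
}

fn main() {
    // 10_000 elements with a 1024-thread limit -> 10 groups of 1024 threads;
    // the kernel's "if (gid >= dst_size) return;" guard drops the 240 extras.
    assert_eq!(launch_dims(10_000, 1024), (10, 1024));
}
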
#[cfg(test)]
mod tests {
use super::*;
@@ -1003,61 +1060,32 @@ mod tests {
dim: usize,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids.len() * left_size * right_size;
let ids_size = ids.len();
let function = library.get_function("is_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
ids_size,
&embeddings_buffer,
&ids_buffer,
&mut dst_buffer
)
);
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
&kernels,
"is_u32_f32",
shape,
ids.len(),
dim,
&embeddings_buffer,
&ids_buffer,
&mut dst_buffer,
)
.unwrap();
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
let grid_size = MTLSize {
width: (dst_el as u64 + width - 1) / width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
println!("{width:?} - {:?}", grid_size);
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();


@@ -18,45 +18,55 @@ METAL_FUNC uint get_strided_index(
constant int THREADGROUP_SIZE = 256;
kernel void fast_sum_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = 0;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
shared_memory[tid] += src[idx];
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
dst[dst_id] = shared_memory[0];
}
# define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint blockDim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup TYPENAME shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = src[idx]; \
shared_memory[tid] = FN; \
idx += blockDim; \
} \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
dst[dst_id] = shared_memory[0]; \
}
kernel void softmax_float(
constant size_t &src_numel,
@@ -122,3 +132,8 @@ kernel void softmax_float(
idx += blockDim;
}
}
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
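
For intuition, a hedged CPU sketch (not in the commit) of the tree reduction the REDUCE macro performs after the strided load: each round pairs slot tid with slot tid + s and halves s, which assumes blockDim is a power of two.

fn block_reduce(shared: &mut [f32], f: impl Fn(f32, f32) -> f32) -> f32 {
    let mut s = shared.len() / 2;
    while s > 0 {
        // One round: in the kernel, every thread with tid < s does one combine,
        // separated from the next round by a threadgroup barrier.
        for tid in 0..s {
            shared[tid] = f(shared[tid], shared[tid + s]);
        }
        s /= 2;
    }
    shared[0]
}

fn main() {
    // fast_sum_float over one 8-slot threadgroup: 4, 2, then 1 pair per round.
    let mut shared = [1.0f32; 8];
    assert_eq!(block_reduce(&mut shared, |x, y| x + y), 8.0);
}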