diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 03e6d810..b3116c86 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -172,8 +172,8 @@ impl BackendStorage for MetalStorage { .unwrap(); } else { let name = match self.dtype { - DType::F32 => "affine_float", - DType::F16 => "affine_half", + DType::F32 => "affine_float_strided", + DType::F16 => "affine_half_strided", dtype => todo!("Affine {dtype:?}"), }; candle_metal_kernels::call_affine_strided( @@ -829,7 +829,6 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); let el_count = src_shape.elem_count(); - // todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}"); if el_count == 0 { return Ok(()); } @@ -851,7 +850,7 @@ impl BackendStorage for MetalStorage { &src_l.stride(), src_l.start_offset() * self.dtype.size_in_bytes(), &mut dst.buffer, - dst_offset, + dst_offset * dst.dtype.size_in_bytes(), ) .map_err(MetalError::from)?; // command_buffer.commit(); diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index e0129ca9..007f9fed 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -16,16 +16,16 @@ kernel void NAME( \ if (gid >= dst_size) { \ return; \ } \ - const size_t id_i = gid / right_size / left_size; \ + const size_t id_i = (gid / right_size) % ids_size; \ + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ const size_t right_rank_i = gid % right_size; \ - const size_t left_rank_i = gid % left_size; \ + const size_t left_rank_i = gid / right_size / ids_size; \ /* \ // Force prevent out of bounds indexing \ // since there doesn't seem to be a good way to force crash \ // No need to check for zero we're only allowing unsized. \ */ \ - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ - const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \ output[gid] = input[src_i]; \ } @@ -75,7 +75,6 @@ kernel void FN_NAME( \ INDEX_OP(is_u32_f32, uint, float) -INDEX_OP(is_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c4a0ca97..2cadc8c6 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -112,7 +112,7 @@ macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { - pub struct Kernel(pub(crate) &'static str); + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; @@ -131,7 +131,7 @@ macro_rules! ops{ } pub mod strided { - pub struct Kernel(pub(crate) &'static str); + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; @@ -172,7 +172,7 @@ pub enum MetalKernelError { LockError(String), #[error("Error while loading library: {0}")] LoadLibraryError(String), - #[error("Error while loading function: {0}")] + #[error("Error while loading function: {0:?}")] LoadFunctionError(String), } @@ -1053,6 +1053,7 @@ mod tests { &device, command_buffer, &kernels, + "affine_float", size, &input, &mut output, @@ -1087,7 +1088,7 @@ mod tests { &device, command_buffer, &kernels, - size, + "affine_float", shape, &input, strides, @@ -1146,7 +1147,10 @@ mod tests { result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] ); + } + #[test] + fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; @@ -1154,7 +1158,7 @@ mod tests { let result = run_index_select(&embedding, &shape, &ids, dim); assert_eq!( result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] ); } diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/tmp/affine.rs similarity index 98% rename from candle-metal-kernels/examples/affine.rs rename to candle-metal-kernels/tmp/affine.rs index b8005dc0..cd019056 100644 --- a/candle-metal-kernels/examples/affine.rs +++ b/candle-metal-kernels/tmp/affine.rs @@ -50,6 +50,7 @@ fn run_affine_bench(device: &Device, kernels: &Kernels, v: &[T]) { &device, command_buffer, &kernels, + "affine_float", v.len(), &input, &mut output, diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/tmp/binary.rs similarity index 100% rename from candle-metal-kernels/examples/binary.rs rename to candle-metal-kernels/tmp/binary.rs diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/tmp/cast.rs similarity index 100% rename from candle-metal-kernels/examples/cast.rs rename to candle-metal-kernels/tmp/cast.rs diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/tmp/unary.rs similarity index 98% rename from candle-metal-kernels/examples/unary.rs rename to candle-metal-kernels/tmp/unary.rs index 7039c098..66cf25c0 100644 --- a/candle-metal-kernels/examples/unary.rs +++ b/candle-metal-kernels/tmp/unary.rs @@ -147,7 +147,7 @@ fn run_unary_bench( println!( "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", type_name::().split("::").last().unwrap(), - kernel_name.to_string(), + kernel_name.0, v.len(), iterations, total_time, @@ -159,7 +159,7 @@ fn run_unary_bench( let shape = vec![2, 5_000]; let strides = vec![2, 1]; let offset = 0; - for kernel_name in strided { + for kernel_name in &strided { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); let start = Instant::now(); @@ -187,7 +187,7 @@ fn run_unary_bench( println!( "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", type_name::().split("::").last().unwrap(), - kernel_name.to_string(), + kernel_name.0, v.len(), iterations, total_time,