From 933716b37485ffb04a1af476736a7e6529bc7b89 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 22 Jan 2024 20:59:02 +0100 Subject: [PATCH] Where cond get_strided_index conditionally based on function constants --- candle-core/src/metal_backend.rs | 3 +++ candle-metal-kernels/src/lib.rs | 11 ++++++++- candle-metal-kernels/src/ternary.metal | 33 +++++++++++++++++++------- candle-metal-kernels/src/tests.rs | 3 +++ 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index ebcad786..7d8cded7 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -822,10 +822,13 @@ impl BackendStorage for MetalStorage { layout.stride(), layout.start_offset() * self.dtype.size_in_bytes(), ), + !layout.is_contiguous(), &t.buffer, (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), + !t_l.is_contiguous(), &f.buffer, (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), + !f_l.is_contiguous(), &buffer, ) .map_err(MetalError::from)?; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index fe969372..ec7484df 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -909,13 +909,22 @@ pub fn call_where_cond_strided( shape: &[usize], cond: &Buffer, (cond_stride, cond_offset): (&[usize], usize), + cond_is_strided: bool, left: &Buffer, (left_stride, left_offset): (&[usize], usize), + left_is_strided: bool, right: &Buffer, (right_stride, right_offset): (&[usize], usize), + right_is_strided: bool, output: &Buffer, ) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; + let constants = Some(ConstantValues::new(vec![ + (0, Value::Bool(cond_is_strided)), + (1, Value::Bool(left_is_strided)), + (2, Value::Bool(right_is_strided)), + ])); + let pipeline = + kernels.load_pipeline_with_constants(device, Source::Ternary, name, constants)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 7b3b8ca9..b5125498 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -1,14 +1,20 @@ #include -# + using namespace metal; +constant bool IDS_STRIDED [[function_constant(0)]]; +constant bool T_STRIDED [[function_constant(1)]]; +constant bool F_STRIDED [[function_constant(2)]]; + + METAL_FUNC uint get_strided_index( uint idx, - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides + constant const size_t &num_dims, + constant const size_t *dims, + constant const size_t *strides ) { uint strided_i = 0; + #pragma clang loop unroll(full) for (uint d = 0; d < num_dims; d++) { uint dim_idx = num_dims - 1 - d; strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; @@ -17,6 +23,7 @@ METAL_FUNC uint get_strided_index( return strided_i; } + template METAL_FUNC void where_cond( constant size_t &numel, @@ -34,10 +41,20 @@ METAL_FUNC void where_cond( if (i >= numel){ return; } - uint strided_i = get_strided_index(i, num_dims, dims, strides); - uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); - uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); - out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; + uint strided_i = i; + uint strided_i_t = i; + uint strided_i_f = i; + if (IDS_STRIDED) { + strided_i = get_strided_index(i, num_dims, dims, strides); + } + if (T_STRIDED) { + strided_i_t = get_strided_index(i, num_dims, dims, strides_t); + } + if (F_STRIDED) { + strided_i_f = get_strided_index(i, num_dims, dims, strides_f); + } + + out[i] = select(f[strided_i_t], t[strided_i_f], ids[strided_i]); } #define WHERE_OP(T, ID, FN_NAME) \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 655161e5..804ed500 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -803,10 +803,13 @@ fn run_where_cond( shape, &cond, (&cond_stride, cond_offset), + true, &left, (&left_stride, left_offset), + true, &right, (&cond_stride, cond_offset), + true, &output, ) .unwrap();