diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 3bee7657..d77fbf3f 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -263,7 +263,7 @@ impl MetalDevice { } } } - return best_buffer.map(|b| b.clone()); + best_buffer.cloned() } fn drop_unused_buffers(&self) -> Result<()> { @@ -1048,8 +1048,46 @@ impl BackendStorage for MetalStorage { crate::bail!("Metal avg_pool2d not implemented") } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - crate::bail!("Metal max_pool2d not implemented") + fn max_pool2d( + &self, + inp_l: &Layout, + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + ) -> Result { + let shape = inp_l.shape(); + let (b_size, channels, width, height) = shape.dims4()?; + let strides = inp_l.stride(); + let name = match self.dtype { + DType::F32 => "max_pool2d_f32", + DType::F16 => "max_pool2d_f16", + DType::BF16 => "max_pool2d_bf16", + DType::U8 => "max_pool2d_u8", + DType::U32 => "max_pool2d_u32", + dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"), + }; + let out_w = (width - w_k) / w_stride + 1; + let out_h = (height - h_k) / h_stride + 1; + let dst_el = out_w * out_h * b_size * channels; + let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?; + let command_buffers = self.device.command_buffer()?; + candle_metal_kernels::call_max_pool2d( + &self.device.device, + &command_buffers, + &self.device.kernels, + name, + inp_l.dims(), + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &self.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) } fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index a6530e03..8e273fb9 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -22,9 +22,6 @@ fn avg_pool2d(dev: &Device) -> Result<()> { } fn max_pool2d(dev: &Device) -> Result<()> { - if dev.is_metal() { - return Ok(()); - } let data: Vec = vec![ 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1., ]; diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal index dca53161..d7c23ddf 100644 --- a/candle-metal-kernels/src/conv.metal +++ b/candle-metal-kernels/src/conv.metal @@ -1,3 +1,9 @@ +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) + template METAL_FUNC void im2col( constant size_t &dst_numel, @@ -200,6 +206,74 @@ kernel void FN_NAME( \ upsample_nearest2d(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \ } \ +template +METAL_FUNC void max_pool2d( + constant size_t &w_k, + constant size_t &h_k, + constant size_t &w_stride, + constant size_t &h_stride, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + T d = 0; + bool set = false; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in){ + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; + if (set) { + d = MAX(d, src[src_idx]); + } + else { + d = src[src_idx]; + set = true; + } + } + } + dst[tid] = d; +} + +#define MAXPOOL2D_OP(TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_k, \ + constant size_t &h_k, \ + constant size_t &w_s, \ + constant size_t &h_s, \ + constant size_t *src_dims, \ + constant size_t *src_s, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + max_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ +} \ + IM2COL_OP(float, im2col_f32) IM2COL_OP(uint8_t, im2col_u8) IM2COL_OP(uint32_t, im2col_u32) @@ -211,3 +285,11 @@ IM2COL1D_OP(uint32_t, im2col1d_u32) UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32) UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8) UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32) + +MAXPOOL2D_OP(float, max_pool2d_f32) +MAXPOOL2D_OP(half, max_pool2d_f16) +MAXPOOL2D_OP(uint32_t, max_pool2d_u32) +MAXPOOL2D_OP(uint8_t, max_pool2d_u8) +#if defined(__HAVE_BFLOAT__) +MAXPOOL2D_OP(bfloat, max_pool2d_bf16) +#endif \ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a879c86a..b1830a25 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1826,5 +1826,38 @@ fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } +#[allow(clippy::too_many_arguments)] +pub fn call_max_pool2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + out_w: usize, + out_h: usize, + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + input: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = out_w * out_h * shape[0] * shape[1]; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (w_k, h_k, w_stride, h_stride, shape, strides, input, output) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index a34882d3..74721153 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,6 @@ use super::*; use half::{bf16, f16}; -use metal::{Buffer, Device, MTLResourceOptions}; +use metal::MTLResourceOptions; fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { let ptr = buffer.contents() as *const T; @@ -1368,3 +1368,240 @@ fn index_add() { assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); } } + +fn run_max_pool2d( + v: &[T], + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + shape: &[usize], + strides: &[usize], + name: &'static str, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let out_w = (shape[2] - w_k) / w_stride + 1; + let out_h = (shape[3] - h_k) / h_stride + 1; + let dst_el = out_w * out_h * shape[0] * shape[1]; + let input = new_buffer(&device, v); + let output = new_buffer(&device, &vec![0.0f32; dst_el]); + let kernels = Kernels::new(); + call_max_pool2d( + &device, + command_buffer, + &kernels, + name, + shape, + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &input, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, dst_el) +} + +#[test] +fn max_pool2d_f32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f32", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f32", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0]; + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_f16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| half::f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f16", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + .iter() + .map(|v| half::f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| half::f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f16", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0] + .iter() + .map(|v| half::f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_bf16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| half::bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_bf16", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + .iter() + .map(|v| half::bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| half::bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_bf16", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0] + .iter() + .map(|v| half::bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_u8() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u8", + ); + let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u8", + ); + let expected = vec![5, 7, 13, 15]; + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_u32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u32", + ); + let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u32", + ); + let expected = vec![5, 7, 13, 15]; + assert_eq!(results, expected); +}