From 04a61a9c72a1f13546c8b7becd95055129fda22f Mon Sep 17 00:00:00 2001 From: Thomas Santerre Date: Mon, 18 Mar 2024 13:50:14 -0400 Subject: [PATCH] Add avg_pool2d metal implementation for the metal backend (#1869) * implement metal avg pool 2d * fixX * add suggested precision workaround for the accumulator --- candle-core/src/metal_backend.rs | 46 +++++++++- candle-core/tests/pool_tests.rs | 3 - candle-metal-kernels/src/conv.metal | 69 ++++++++++++++ candle-metal-kernels/src/lib.rs | 2 +- candle-metal-kernels/src/tests.rs | 136 +++++++++++++++++++++++++--- 5 files changed, 236 insertions(+), 20 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index d77fbf3f..cf354f45 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1044,8 +1044,46 @@ impl BackendStorage for MetalStorage { crate::bail!("Metal conv_tranpose2d not implemented") } - fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - crate::bail!("Metal avg_pool2d not implemented") + fn avg_pool2d( + &self, + inp_l: &Layout, + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + ) -> Result { + let shape = inp_l.shape(); + let (b_size, channels, width, height) = shape.dims4()?; + let strides = inp_l.stride(); + let name = match self.dtype { + DType::F32 => "avg_pool2d_f32", + DType::F16 => "avg_pool2d_f16", + DType::BF16 => "avg_pool2d_bf16", + DType::U8 => "avg_pool2d_u8", + DType::U32 => "avg_pool2d_u32", + dtype => crate::bail!("Metal avg_pool2d {dtype:?} not implemented"), + }; + let out_w = (width - w_k) / w_stride + 1; + let out_h = (height - h_k) / h_stride + 1; + let dst_el = out_w * out_h * b_size * channels; + let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?; + let command_buffers = self.device.command_buffer()?; + candle_metal_kernels::call_pool2d( + &self.device.device, + &command_buffers, + &self.device.kernels, + name, + inp_l.dims(), + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &self.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) } fn max_pool2d( @@ -1063,14 +1101,14 @@ impl BackendStorage for MetalStorage { DType::BF16 => "max_pool2d_bf16", DType::U8 => "max_pool2d_u8", DType::U32 => "max_pool2d_u32", - dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"), + dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"), }; let out_w = (width - w_k) / w_stride + 1; let out_h = (height - h_k) / h_stride + 1; let dst_el = out_w * out_h * b_size * channels; let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?; let command_buffers = self.device.command_buffer()?; - candle_metal_kernels::call_max_pool2d( + candle_metal_kernels::call_pool2d( &self.device.device, &command_buffers, &self.device.kernels, diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index 8e273fb9..1edb7d35 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -2,9 +2,6 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor}; // https://github.com/huggingface/candle/issues/364 fn avg_pool2d(dev: &Device) -> Result<()> { - if dev.is_metal() { - return Ok(()); - } let data: Vec = vec![ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., ]; diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal index d7c23ddf..7f7a75cf 100644 --- a/candle-metal-kernels/src/conv.metal +++ b/candle-metal-kernels/src/conv.metal @@ -206,6 +206,67 @@ kernel void FN_NAME( \ upsample_nearest2d(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \ } \ +template +METAL_FUNC void avg_pool2d( + constant size_t &w_k, + constant size_t &h_k, + constant size_t &w_stride, + constant size_t &h_stride, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + A d = 0; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in){ + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; + d += static_cast(src[src_idx]); + } + } + dst[tid] = static_cast(d / (w_k * h_k)); +} + +#define AVGPOOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_k, \ + constant size_t &h_k, \ + constant size_t &w_s, \ + constant size_t &h_s, \ + constant size_t *src_dims, \ + constant size_t *src_s, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + avg_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ +} \ + template METAL_FUNC void max_pool2d( constant size_t &w_k, @@ -292,4 +353,12 @@ MAXPOOL2D_OP(uint32_t, max_pool2d_u32) MAXPOOL2D_OP(uint8_t, max_pool2d_u8) #if defined(__HAVE_BFLOAT__) MAXPOOL2D_OP(bfloat, max_pool2d_bf16) +#endif + +AVGPOOL2D_OP(float, float, avg_pool2d_f32) +AVGPOOL2D_OP(half, float, avg_pool2d_f16) +AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32) +AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8) +#if defined(__HAVE_BFLOAT__) +AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16) #endif \ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index b1830a25..1161501f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1827,7 +1827,7 @@ fn divide(m: usize, b: usize) -> NSUInteger { } #[allow(clippy::too_many_arguments)] -pub fn call_max_pool2d( +pub fn call_pool2d( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 74721153..19e160dd 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1369,7 +1369,7 @@ fn index_add() { } } -fn run_max_pool2d( +fn run_pool2d( v: &[T], (w_k, h_k): (usize, usize), (w_stride, h_stride): (usize, usize), @@ -1386,7 +1386,7 @@ fn run_max_pool2d( let input = new_buffer(&device, v); let output = new_buffer(&device, &vec![0.0f32; dst_el]); let kernels = Kernels::new(); - call_max_pool2d( + call_pool2d( &device, command_buffer, &kernels, @@ -1417,7 +1417,7 @@ fn max_pool2d_f32() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1434,7 +1434,7 @@ fn max_pool2d_f32() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1454,7 +1454,7 @@ fn max_pool2d_f16() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1474,7 +1474,7 @@ fn max_pool2d_f16() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1497,7 +1497,7 @@ fn max_pool2d_bf16() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1517,7 +1517,7 @@ fn max_pool2d_bf16() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1540,7 +1540,7 @@ fn max_pool2d_u8() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1557,7 +1557,7 @@ fn max_pool2d_u8() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1577,7 +1577,7 @@ fn max_pool2d_u32() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1594,7 +1594,7 @@ fn max_pool2d_u32() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1605,3 +1605,115 @@ fn max_pool2d_u32() { let expected = vec![5, 7, 13, 15]; assert_eq!(results, expected); } + +#[test] +fn avg_pool2d_f32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_f32", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_f16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_f16", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_bf16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_bf16", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_u8() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_u8", + ); + let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_u32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_u32", + ); + let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; + assert_eq!(results, expected); +}