From a094dc503d69a6ca3db71098ebc26d0d2f2a33a6 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 14 Aug 2023 12:32:05 +0100 Subject: [PATCH] Add a cuda kernel for avg-pool2d. (#440) * Add a cuda kernel for avg-pool2d. * Avoid running out of bounds. * Finish wiring the avg pool kernel + add some testing. * Support for max-pool + testing. --- candle-core/src/cuda_backend.rs | 84 ++++++++++++++++- candle-core/tests/pool_tests.rs | 28 +++--- candle-kernels/src/conv.cu | 160 +++++++++++++++++++++++++++++++- 3 files changed, 253 insertions(+), 19 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 6129e100..90d3ee6d 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -960,6 +960,64 @@ impl<'a> Map2 for Conv2D<'a> { } } +enum PoolOp { + Max, + Avg, +} + +struct Pool2D { + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + op: PoolOp, +} + +impl Map1 for Pool2D { + fn f( + &self, + inp: &CudaSlice, + dev: &CudaDevice, + inp_l: &Layout, + ) -> Result> { + // Kernel shape: (c_out, c_in_k, w_k, h_k) + let inp = &inp.slice(inp_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let ds = if dims.len() == 4 { + [dims, inp_l.stride()].concat() + } else { + panic!("unexpected input shape for conv1d {dims:?}") + }; + let el = shape.elem_count(); + let out_w = (dims[2] - self.w_k) / self.w_stride + 1; + let out_h = (dims[3] - self.h_k) / self.h_stride + 1; + let dst_el = out_w * out_h * dims[0] * dims[1]; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let kname = match self.op { + PoolOp::Max => "max_pool2d", + PoolOp::Avg => "avg_pool2d", + }; + let func = dev.get_or_load_func(&kernel_name::(kname), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let ds = dev.htod_copy(ds).w()?; + let params = ( + el, + self.w_k, + self.h_k, + self.w_stride, + self.h_stride, + &ds, + inp, + &out, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + struct WhereCond<'a>(&'a CudaStorage, &'a Layout); impl<'a> Map2 for WhereCond<'a> { fn f( @@ -1429,12 +1487,30 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - todo!() + fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result { + let device = self.device().clone(); + let slice = Pool2D { + w_k: k.0, + h_k: k.1, + w_stride: stride.0, + h_stride: stride.1, + op: PoolOp::Avg, + } + .map(&self.slice, &device, l)?; + Ok(Self { slice, device }) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - todo!() + fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result { + let device = self.device().clone(); + let slice = Pool2D { + w_k: k.0, + h_k: k.1, + w_stride: stride.0, + h_stride: stride.1, + op: PoolOp::Max, + } + .map(&self.slice, &device, l)?; + Ok(Self { slice, device }) } fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index 73bf7434..009564fa 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -1,25 +1,22 @@ mod test_utils; -use candle_core::{Device, IndexOp, Tensor}; +use candle_core::{Device, IndexOp, Result, Tensor}; // https://github.com/huggingface/candle/issues/364 -#[test] -fn avg_pool2d() -> anyhow::Result<()> { +fn avg_pool2d(dev: &Device) -> Result<()> { let data: Vec = vec![ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., ]; - let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?; - + let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?; let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?; assert_eq!(pool.to_vec2::()?, [[0.5f32, 1.], [1., 1.]]); Ok(()) } -#[test] -fn max_pool2d() -> anyhow::Result<()> { +fn max_pool2d(dev: &Device) -> Result<()> { let data: Vec = vec![ 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1., ]; - let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?; + let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?; let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?; assert_eq!(pool.to_vec2::()?, [[2f32, 3.], [5., 1.]]); @@ -35,8 +32,7 @@ print(t.flatten()) res = torch.nn.functional.avg_pool2d(t, 2) print(res) */ -#[test] -fn avg_pool2d_pytorch() -> anyhow::Result<()> { +fn avg_pool2d_pytorch(dev: &Device) -> Result<()> { let t = Tensor::new( &[ 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, @@ -44,7 +40,7 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> { 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836, 0.2477, 1.3127, ], - &Device::Cpu, + dev, )? .reshape((1, 2, 4, 4))?; let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?; @@ -61,7 +57,7 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> { } #[test] -fn upsample_nearest2d() -> anyhow::Result<()> { +fn upsample_nearest2d() -> Result<()> { let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?; let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?; assert_eq!( @@ -79,3 +75,11 @@ fn upsample_nearest2d() -> anyhow::Result<()> { ); Ok(()) } + +test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu); +test_device!( + avg_pool2d_pytorch, + avg_pool2d_pytorch_cpu, + avg_pool2d_pytorch_gpu +); +test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu); diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index 722ca11e..2da4d401 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -24,6 +24,9 @@ __device__ void conv1d( const size_t c_out = k_dims[0]; const size_t c_in = src_dims[1]; const size_t l_in = src_dims[2]; + if (dst_i >= src_dims[0] * c_out * l_out) { + return; + } // TODO const size_t b_idx = dst_i / (l_out * c_out); @@ -61,9 +64,6 @@ __device__ void conv2d( T *dst ) { const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; - if (dst_i >= src_numel) { - return; - } // src: (b_size, c_in, w_in, h_in) // k: (c_out, c_in, w_k, h_k) const size_t *src_dims = info; @@ -76,6 +76,9 @@ __device__ void conv2d( const size_t c_in = src_dims[1]; const size_t w_in = src_dims[2]; const size_t h_in = src_dims[3]; + if (dst_i >= src_dims[0] * c_out * w_out * h_out) { + return; + } // TODO const size_t b_idx = dst_i / (w_out * h_out * c_out); @@ -107,6 +110,116 @@ __device__ void conv2d( dst[dst_i] = static_cast(d); } +template +__device__ void avg_pool2d( + const size_t src_numel, + const size_t w_k, + const size_t h_k, + const size_t w_stride, + const size_t h_stride, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, w_in, h_in) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (dst_i >= src_dims[0] * c * w_out * h_out) { + return; + } + + // TODO: Improve this. + const size_t b_idx = dst_i / (w_out * h_out * c); + const size_t c_idx = (dst_i / (w_out * h_out)) % c; + const size_t dst_w = (dst_i / h_out) % w_out; + const size_t dst_h = dst_i % h_out; + + const size_t src_idx0 = b_idx * src_s[0]; + const float scale = 1.0 / (w_k * h_k); + A d = 0; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in) { + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + d += static_cast(src[src_idx]); + } + } + dst[dst_i] = static_cast(d * scale); +} + +template +__device__ void max_pool2d( + const size_t src_numel, + const size_t w_k, + const size_t h_k, + const size_t w_stride, + const size_t h_stride, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, w_in, h_in) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (dst_i >= src_dims[0] * c * w_out * h_out) { + return; + } + + // TODO: Improve this. + const size_t b_idx = dst_i / (w_out * h_out * c); + const size_t c_idx = (dst_i / (w_out * h_out)) % c; + const size_t dst_w = (dst_i / h_out) % w_out; + const size_t dst_h = dst_i % h_out; + + const size_t src_idx0 = b_idx * src_s[0]; + T d = 0; + bool set = false; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in) { + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + if (set) { + d = maxg(d, src[src_idx]); + } + else { + d = src[src_idx]; + set = true; + } + } + } + dst[dst_i] = d; +} + #define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \ extern "C" __global__ void FN_NAME( \ @@ -137,14 +250,46 @@ extern "C" __global__ void FN_NAME( \ conv2d(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \ } \ +#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t w_k, \ + const size_t h_k, \ + const size_t w_stride, \ + const size_t h_stride, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + avg_pool2d(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \ +} \ + +#define MAX_POOL2D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t w_k, \ + const size_t h_k, \ + const size_t w_stride, \ + const size_t h_stride, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + max_pool2d(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \ +} \ + #if __CUDA_ARCH__ >= 800 CONV1D_OP(__nv_bfloat16, float, conv1d_bf16) CONV2D_OP(__nv_bfloat16, float, conv2d_bf16) +AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16) +MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16) #endif #if __CUDA_ARCH__ >= 530 CONV1D_OP(__half, float, conv1d_f16) CONV2D_OP(__half, float, conv2d_f16) +AVG_POOL2D_OP(__half, float, avg_pool2d_f16) +MAX_POOL2D_OP(__half, max_pool2d_f16) #endif CONV1D_OP(float, float, conv1d_f32) @@ -157,3 +302,12 @@ CONV2D_OP(double, double, conv2d_f64) CONV2D_OP(uint8_t, uint8_t, conv2d_u8) CONV2D_OP(uint32_t, uint32_t, conv2d_u32) +AVG_POOL2D_OP(float, float, avg_pool2d_f32) +AVG_POOL2D_OP(double, double, avg_pool2d_f64) +AVG_POOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8) +AVG_POOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32) + +MAX_POOL2D_OP(float, max_pool2d_f32) +MAX_POOL2D_OP(double, max_pool2d_f64) +MAX_POOL2D_OP(uint8_t, max_pool2d_u8) +MAX_POOL2D_OP(uint32_t, max_pool2d_u32)