diff --git a/Makefile b/Makefile index eba92821..cc0a0a36 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,8 @@ clean-ptx: find target -name "*.ptx" -type f -delete echo "" > candle-kernels/src/lib.rs touch candle-kernels/build.rs + touch candle-examples/build.rs + touch candle-flash-attn/build.rs clean: cargo clean diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index a7f63353..6129e100 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -897,7 +897,6 @@ impl<'a> Map2 for Conv1D<'a> { // Kernel shape: (c_out, c_in_k, k_size) // Input shape: (b_size, c_in, l_in) or (c_in, l_in) let p = &self.0; - let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(k_l.start_offset()..); let shape = inp_l.shape(); @@ -917,7 +916,44 @@ impl<'a> Map2 for Conv1D<'a> { panic!("unexpected input shape for conv1d {dims:?}") }; let ds = dev.htod_copy(ds).w()?; - let params = (el, l_out, p.stride, &ds, inp, k, &out); + let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); +impl<'a> Map2 for Conv2D<'a> { + fn f( + &self, + inp: &CudaSlice, + inp_l: &Layout, + k: &CudaSlice, + k_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + // Kernel shape: (c_out, c_in_k, w_k, h_k) + // Input shape: (b_size, c_in, w_in, c_in) + let p = &self.0; + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(k_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let (out_w, out_h) = (p.out_w(), p.out_h()); + let dst_el = p.c_out * out_w * out_h * p.b_size; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let func = dev.get_or_load_func(&kernel_name::("conv2d"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let ds = if dims.len() == 4 { + [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() + } else { + panic!("unexpected input shape for conv1d {dims:?}") + }; + let ds = dev.htod_copy(ds).w()?; + let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) @@ -1383,12 +1419,14 @@ impl BackendStorage for CudaStorage { fn conv2d( &self, - _l: &Layout, - _kernel: &Self, - _kernel_l: &Layout, - _params: &crate::conv::ParamsConv2D, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, ) -> Result { - todo!() + let device = self.device().clone(); + let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + Ok(Self { slice, device }) } fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index f955b4a5..c777fec7 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -15,9 +15,7 @@ print(res.flatten()) res = torch.nn.functional.conv1d(t, w, padding=1) print(res.flatten()) */ -#[test] -fn conv1d() -> Result<()> { - let dev = &Device::Cpu; +fn conv1d(dev: &Device) -> Result<()> { let t = Tensor::new( &[ 0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145, @@ -51,9 +49,7 @@ fn conv1d() -> Result<()> { Ok(()) } -#[test] -fn conv1d_small() -> Result<()> { - let dev = &Device::Cpu; +fn conv1d_small(dev: &Device) -> Result<()> { let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?; let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?; let res = t.conv1d(&w, 0, 1)?; @@ -82,9 +78,7 @@ print(w.flatten()) res = torch.nn.functional.conv2d(t, w) print(res.flatten()) */ -#[test] -fn conv2d() -> Result<()> { - let dev = &Device::Cpu; +fn conv2d(dev: &Device) -> Result<()> { let t = Tensor::new( &[ 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, @@ -138,9 +132,7 @@ print(w.flatten()) res = torch.nn.functional.conv2d(t, w) print(res.flatten()) */ -#[test] -fn conv2d_small() -> Result<()> { - let dev = &Device::Cpu; +fn conv2d_small(dev: &Device) -> Result<()> { let t = Tensor::new( &[ 0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145, @@ -160,9 +152,7 @@ fn conv2d_small() -> Result<()> { Ok(()) } -#[test] -fn conv2d_smaller() -> Result<()> { - let dev = &Device::Cpu; +fn conv2d_smaller(dev: &Device) -> Result<()> { let t = Tensor::new( &[ 0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, @@ -180,3 +170,9 @@ fn conv2d_smaller() -> Result<()> { ); Ok(()) } + +test_device!(conv1d, conv1d_cpu, conv1d_gpu); +test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu); +test_device!(conv2d, conv2d_cpu, conv2d_gpu); +test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu); +test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu); diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index 93ef56f3..722ca11e 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -1,11 +1,13 @@ #include "cuda_utils.cuh" #include +// Naive implementation of conv1d. template __device__ void conv1d( const size_t src_numel, const size_t l_out, - const size_t stride, + const size_t stride, + const size_t padding, const size_t *info, const T *src, const T *kernel, @@ -19,7 +21,6 @@ __device__ void conv1d( const size_t *k_s = info + 9; const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; const size_t k_size = k_dims[2]; - const size_t k_over_2 = k_size / 2; const size_t c_out = k_dims[0]; const size_t c_in = src_dims[1]; const size_t l_in = src_dims[2]; @@ -32,12 +33,73 @@ __device__ void conv1d( const size_t src_idx0 = b_idx * src_s[0]; A d = 0; for (size_t offset = 0; offset < k_size; ++offset) { - const size_t src_l_plus = stride * dst_l + offset; - if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) { - const size_t src_l = src_l_plus - k_over_2; + size_t src_l = stride * dst_l + offset; + if (src_l < padding || src_l >= padding + l_in) { + continue; + } + src_l -= padding; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2]; + const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2]; + d += static_cast(src[src_idx]) * static_cast(kernel[k_idx]); + } + } + dst[dst_i] = static_cast(d); +} + +// Naive implementation of conv2d. +template +__device__ void conv2d( + const size_t src_numel, + const size_t w_out, + const size_t h_out, + const size_t stride, + const size_t padding, + const size_t *info, + const T *src, + const T *kernel, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + if (dst_i >= src_numel) { + return; + } + // src: (b_size, c_in, w_in, h_in) + // k: (c_out, c_in, w_k, h_k) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + const size_t *k_dims = info + 8; + const size_t *k_s = info + 12; + const size_t w_k = k_dims[2]; + const size_t h_k = k_dims[3]; + const size_t c_out = k_dims[0]; + const size_t c_in = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + // TODO + const size_t b_idx = dst_i / (w_out * h_out * c_out); + const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out; + const size_t dst_w = (dst_i / h_out) % w_out; + const size_t dst_h = dst_i % h_out; + + const size_t src_idx0 = b_idx * src_s[0]; + A d = 0; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = stride * dst_w + w_offset; + if (src_w < padding || src_w >= w_in + padding) { + continue; + } + src_w -= padding; + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = stride * dst_h + h_offset; + if (src_h < padding || src_h >= h_in + padding) { + continue; + } + src_h -= padding; for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { - const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2]; - const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2]; + const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + w_offset * k_s[2] + h_offset * k_s[3]; d += static_cast(src[src_idx]) * static_cast(kernel[k_idx]); } } @@ -51,20 +113,38 @@ extern "C" __global__ void FN_NAME( \ const size_t src_numel, \ const size_t num_dims, \ const size_t stride, \ + const size_t padding, \ const size_t *info, \ const TYPENAME *src, \ const TYPENAME *kernel, \ TYPENAME *dst \ ) { \ - conv1d(src_numel, num_dims, stride, info, src, kernel, dst); \ + conv1d(src_numel, num_dims, stride, padding, info, src, kernel, dst); \ +} \ + +#define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t w_out, \ + const size_t h_out, \ + const size_t stride, \ + const size_t padding, \ + const size_t *info, \ + const TYPENAME *src, \ + const TYPENAME *kernel, \ + TYPENAME *dst \ +) { \ + conv2d(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \ } \ #if __CUDA_ARCH__ >= 800 CONV1D_OP(__nv_bfloat16, float, conv1d_bf16) +CONV2D_OP(__nv_bfloat16, float, conv2d_bf16) #endif #if __CUDA_ARCH__ >= 530 CONV1D_OP(__half, float, conv1d_f16) +CONV2D_OP(__half, float, conv2d_f16) #endif CONV1D_OP(float, float, conv1d_f32) @@ -72,3 +152,8 @@ CONV1D_OP(double, double, conv1d_f64) CONV1D_OP(uint8_t, uint8_t, conv1d_u8) CONV1D_OP(uint32_t, uint32_t, conv1d_u32) +CONV2D_OP(float, float, conv2d_f32) +CONV2D_OP(double, double, conv2d_f64) +CONV2D_OP(uint8_t, uint8_t, conv2d_u8) +CONV2D_OP(uint32_t, uint32_t, conv2d_u32) +