mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Cuda conv transpose (#645)
* Cuda kernel for conv-transpose. * Fix the cuda kernel. * Fix the tests.
This commit is contained in:
@ -977,8 +977,8 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_out, c_in_k, w_k, h_k)
|
||||
// Input shape: (b_size, c_in, w_in, c_in)
|
||||
// Kernel shape: (c_out, c_in_k, h_k, w_k)
|
||||
// Input shape: (b_size, c_in, h_in, w_in)
|
||||
let p = &self.0;
|
||||
let (out_w, out_h) = (p.out_w(), p.out_h());
|
||||
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
||||
@ -1005,6 +1005,55 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
||||
impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
inp_l: &Layout,
|
||||
k: &CudaSlice<T>,
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_in_k, c_out, h_k, w_k)
|
||||
// Input shape: (b_size, c_in, h_in, w_in)
|
||||
let p = &self.0;
|
||||
let (out_w, out_h) = (p.out_w(), p.out_h());
|
||||
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (
|
||||
el,
|
||||
out_w,
|
||||
out_h,
|
||||
p.stride,
|
||||
p.padding,
|
||||
p.output_padding,
|
||||
&ds,
|
||||
inp,
|
||||
k,
|
||||
&out,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
enum PoolOp {
|
||||
Max,
|
||||
Avg,
|
||||
@ -1649,12 +1698,15 @@ impl BackendStorage for CudaStorage {
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose2D,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
let device = self.device().clone();
|
||||
let slice =
|
||||
ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
|
@ -122,7 +122,6 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
if dev.is_cpu() {
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||
assert_eq!(
|
||||
@ -148,7 +147,6 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
]
|
||||
]
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -202,8 +200,6 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
|
||||
]
|
||||
);
|
||||
// TODO: enable the test for cuda once we have the proper implementation in place.
|
||||
if dev.is_cpu() {
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||
assert_eq!(
|
||||
@ -215,13 +211,12 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
-0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728,
|
||||
0.528, -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838,
|
||||
0.5802, -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396,
|
||||
-0.8156, 0.4594, 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
|
||||
-0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
|
||||
-1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
|
||||
-0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
|
||||
2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
|
||||
]
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -275,10 +270,8 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv2d_grad() -> Result<()> {
|
||||
fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
use candle_core::Var;
|
||||
let dev = &Device::Cpu;
|
||||
let t = Var::from_slice(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||
@ -318,32 +311,28 @@ fn conv2d_grad() -> Result<()> {
|
||||
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
|
||||
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&grad_t.flatten_all()?, 4)?,
|
||||
test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
|
||||
[
|
||||
9.2868, -2.8352, -5.7117, 3.3817, -7.7094, -19.1549, 7.016, 29.1037, 9.3411, 34.7339,
|
||||
-22.8726, 24.3502, -39.88, -14.007, 21.076, 9.9419, 13.6333, -34.6796, 11.2073,
|
||||
-6.2617, 7.7209, -6.3224, -16.6373, -1.0837, -20.2215, 21.7302, -0.3744, -4.0573,
|
||||
5.8163, -3.6529, -30.7319, 14.5468, 87.699, 31.6035, 4.5304, -89.785, -75.3709,
|
||||
-57.4327, -7.5602, 92.9585, 18.791, -4.6311, -159.7521, -42.4656, -47.2644, 52.8768,
|
||||
37.3172, 48.9978, 12.8192, 2.014, -8.9826, 20.1759, 16.621, 12.0599, 15.3849, 19.9979,
|
||||
2.5725, -15.2197, 72.6244, -10.7496, 2.2541, -31.2003, 3.753, -0.2049, 9.7574, -0.6824,
|
||||
5.2107, -40.4361, -22.5891, -61.6085, 17.2837, 20.4149, 37.5454, 5.2262, 6.8126,
|
||||
23.5361, 23.6173, -9.9866, -9.1324, 4.8664, -35.0617, -26.1023, 63.4757, 25.8144,
|
||||
-39.2069, -70.6834, -46.9565, 2.3252, 41.8093, 82.4205, -28.626, -11.7812, -35.3284,
|
||||
-10.2771, -28.5694, -9.1258, 7.213, -9.0459, -9.6222, -11.2544
|
||||
9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
|
||||
-39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
|
||||
-20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
|
||||
-75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
|
||||
49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
|
||||
-10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
|
||||
20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
|
||||
25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
|
||||
-28.57, -9.13, 7.21, -9.05, -9.62, -11.25
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&grad_w.flatten_all()?, 4)?,
|
||||
test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
|
||||
[
|
||||
-28.9232, -22.8833, -141.2296, 73.3462, 61.074, 47.8125, -20.0013, -73.7086, -41.8217,
|
||||
-13.5919, 21.501, 28.7179, 28.5683, -46.8486, -90.1874, 143.6107, 16.6764, 7.4259,
|
||||
18.8794, -90.8122, -20.2865, 54.7909, 82.6287, 22.943, 77.8084, -16.3928, -13.1977,
|
||||
9.3442, -40.3869, -26.6153, 5.3344, -60.9081, 9.0869, -59.368, 7.081, 58.6391, 5.5476,
|
||||
20.5152, 2.4985, -17.2466, -6.802, 22.2146, 30.1511, -7.5179, -37.4588, 5.6654,
|
||||
22.5832, 9.0316, 47.0547, 17.6123, 37.3121, -98.1295, -14.6141, -4.7958, -6.3597,
|
||||
44.6949, 23.3418, 8.3728, -13.52, 80.0522, -34.2403, -16.3648, -12.3139, 1.9195,
|
||||
-33.6244, -14.102, -49.2305, -7.3853, 11.4995, -9.9826, 9.6588, 29.6042
|
||||
-28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
|
||||
28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
|
||||
22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
|
||||
58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
|
||||
47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
|
||||
-34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
@ -359,3 +348,4 @@ test_device!(
|
||||
);
|
||||
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
||||
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|
||||
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
|
||||
|
@ -111,6 +111,71 @@ __device__ void conv2d(
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
// Naive implementation of conv_transpose2d.
|
||||
template <typename T, typename A>
|
||||
__device__ void conv_transpose2d(
|
||||
const size_t src_numel,
|
||||
const size_t w_out,
|
||||
const size_t h_out,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t out_padding,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, c_in, h_in, w_in)
|
||||
// k: (c_in, c_out, h_k, w_k)
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
const size_t *k_dims = info + 8;
|
||||
const size_t *k_s = info + 12;
|
||||
const size_t h_k = k_dims[2];
|
||||
const size_t w_k = k_dims[3];
|
||||
const size_t c_out = k_dims[1];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t h_in = src_dims[2];
|
||||
const size_t w_in = src_dims[3];
|
||||
if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO
|
||||
const size_t b_idx = dst_i / (w_out * h_out * c_out);
|
||||
const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
|
||||
// NCHW layout.
|
||||
const size_t out_y = (dst_i / w_out) % h_out;
|
||||
const size_t out_x = dst_i % w_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
A d = 0;
|
||||
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
|
||||
// let out_x = inp_x * p.stride + k_x - p.padding;
|
||||
int inp_x_stride = (int)(out_x + padding) - k_x;
|
||||
if (inp_x_stride < 0 || inp_x_stride % stride) {
|
||||
continue;
|
||||
}
|
||||
int inp_x = inp_x_stride / stride;
|
||||
if (inp_x >= w_in) continue;
|
||||
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
|
||||
int inp_y_stride = (int)(out_y + padding) - k_y;
|
||||
if (inp_y_stride < 0 || inp_y_stride % stride) {
|
||||
continue;
|
||||
}
|
||||
int inp_y = inp_y_stride / stride;
|
||||
if (inp_y >= h_in) continue;
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_y * src_s[2] + inp_x * src_s[3];
|
||||
const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
template <typename T, typename A>
|
||||
__device__ void avg_pool2d(
|
||||
const size_t src_numel,
|
||||
@ -293,6 +358,22 @@ extern "C" __global__ void FN_NAME( \
|
||||
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
const size_t w_out, \
|
||||
const size_t h_out, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t out_padding, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
@ -337,6 +418,7 @@ extern "C" __global__ void FN_NAME( \
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
|
||||
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
|
||||
CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
|
||||
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
|
||||
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
||||
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
||||
@ -345,6 +427,7 @@ UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
CONV1D_OP(__half, float, conv1d_f16)
|
||||
CONV2D_OP(__half, float, conv2d_f16)
|
||||
CONVT2D_OP(__half, float, conv_transpose2d_f16)
|
||||
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
|
||||
MAX_POOL2D_OP(__half, max_pool2d_f16)
|
||||
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
||||
@ -360,6 +443,11 @@ CONV2D_OP(double, double, conv2d_f64)
|
||||
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
|
||||
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
|
||||
|
||||
CONVT2D_OP(float, float, conv_transpose2d_f32)
|
||||
CONVT2D_OP(double, double, conv_transpose2d_f64)
|
||||
CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)
|
||||
CONVT2D_OP(uint32_t, uint32_t, conv_transpose2d_u32)
|
||||
|
||||
AVG_POOL2D_OP(float, float, avg_pool2d_f32)
|
||||
AVG_POOL2D_OP(double, double, avg_pool2d_f64)
|
||||
AVG_POOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
|
||||
|
Reference in New Issue
Block a user