Cudnn fix. (#758)

This commit is contained in:
Laurent Mazare
2023-09-06 18:39:39 +02:00
committed by GitHub
parent bdc9d46fe3
commit 7b1f2da828

View File

@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
let x_shape = [ let x_shape = [
params.b_size as i32, params.b_size as i32,
params.c_in as i32, params.c_in as i32,
params.i_w as i32,
params.i_h as i32, params.i_h as i32,
params.i_w as i32,
]; ];
// Note that `src` already starts at the proper offset. // Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() { let x = if src_l.is_contiguous() {
@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
[ [
params.c_out as i32, params.c_out as i32,
params.c_in as i32, params.c_in as i32,
params.k_w as i32,
params.k_h as i32, params.k_h as i32,
params.k_w as i32,
], ],
)?; )?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32); let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor( let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, w_out, h_out], [params.b_size as i32, params.c_out as i32, h_out, w_out],
)?; )?;
let conv2d = Conv2dForward { let conv2d = Conv2dForward {
conv: &conv, conv: &conv,