Cudnn fix. (#758)

This commit is contained in:
Laurent Mazare
2023-09-06 18:39:39 +02:00
committed by GitHub
parent bdc9d46fe3
commit 7b1f2da828

View File

@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.i_w as i32,
params.i_h as i32,
params.i_w as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
params.k_w as i32,
params.k_h as i32,
params.k_w as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, w_out, h_out],
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,