mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Cudnn fix. (#758)
This commit is contained in:
@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
|
||||
let x_shape = [
|
||||
params.b_size as i32,
|
||||
params.c_in as i32,
|
||||
params.i_w as i32,
|
||||
params.i_h as i32,
|
||||
params.i_w as i32,
|
||||
];
|
||||
// Note that `src` already starts at the proper offset.
|
||||
let x = if src_l.is_contiguous() {
|
||||
@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
|
||||
[
|
||||
params.c_out as i32,
|
||||
params.c_in as i32,
|
||||
params.k_w as i32,
|
||||
params.k_h as i32,
|
||||
params.k_w as i32,
|
||||
],
|
||||
)?;
|
||||
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
|
||||
let y = cudnn.create_4d_tensor(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, w_out, h_out],
|
||||
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
||||
)?;
|
||||
let conv2d = Conv2dForward {
|
||||
conv: &conv,
|
||||
|
Reference in New Issue
Block a user