mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Cudnn fix. (#758)
This commit is contained in:
@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
|
|||||||
let x_shape = [
|
let x_shape = [
|
||||||
params.b_size as i32,
|
params.b_size as i32,
|
||||||
params.c_in as i32,
|
params.c_in as i32,
|
||||||
params.i_w as i32,
|
|
||||||
params.i_h as i32,
|
params.i_h as i32,
|
||||||
|
params.i_w as i32,
|
||||||
];
|
];
|
||||||
// Note that `src` already starts at the proper offset.
|
// Note that `src` already starts at the proper offset.
|
||||||
let x = if src_l.is_contiguous() {
|
let x = if src_l.is_contiguous() {
|
||||||
@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
|
|||||||
[
|
[
|
||||||
params.c_out as i32,
|
params.c_out as i32,
|
||||||
params.c_in as i32,
|
params.c_in as i32,
|
||||||
params.k_w as i32,
|
|
||||||
params.k_h as i32,
|
params.k_h as i32,
|
||||||
|
params.k_w as i32,
|
||||||
],
|
],
|
||||||
)?;
|
)?;
|
||||||
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
|
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
|
||||||
let y = cudnn.create_4d_tensor(
|
let y = cudnn.create_4d_tensor(
|
||||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||||
[params.b_size as i32, params.c_out as i32, w_out, h_out],
|
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
||||||
)?;
|
)?;
|
||||||
let conv2d = Conv2dForward {
|
let conv2d = Conv2dForward {
|
||||||
conv: &conv,
|
conv: &conv,
|
||||||
|
Reference in New Issue
Block a user