Cudnn support (#445)

* Add a cudnn feature to be used for conv2d.

* Allocate the proper workspace.

* Only create a single cudnn handle per cuda device.

* Proper cudnn usage.

* Bugfix.
This commit is contained in:
Laurent Mazare
2023-08-14 21:30:41 +01:00
committed by GitHub
parent c84883ecf2
commit 90374097dc
7 changed files with 195 additions and 12 deletions

View File

@ -9,10 +9,9 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
let sum = t.sum_keepdim(0)?;
println!("{sum}");
let sum = t.sum_keepdim(1)?;
println!("{sum}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1)?;
println!("{res:?}");
Ok(())
}