Cudnn support (#445)

* Add a cudnn feature to be used for conv2d.

* Allocate the proper workspace.

* Only create a single cudnn handle per cuda device.

* Proper cudnn usage.

* Bugfix.
This commit is contained in:
Laurent Mazare
2023-08-14 21:30:41 +01:00
committed by GitHub
parent c84883ecf2
commit 90374097dc
7 changed files with 195 additions and 12 deletions

View File

@ -47,6 +47,7 @@ anyhow = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]