From fb660b8d430658ff434eee96515cc5dadcf973a1 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 13 Apr 2025 17:43:41 +0200 Subject: [PATCH] Add a cudnn feature to candle-nn/candle-transformers. (#2890) --- candle-examples/Cargo.toml | 2 +- candle-flash-attn/Cargo.toml | 5 ++++- candle-nn/Cargo.toml | 1 + candle-transformers/Cargo.toml | 1 + 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 6633ec50..0d5f3cb6 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"] -cudnn = ["candle/cudnn"] +cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] nccl = ["cuda", "cudarc/nccl", "dep:half"] diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index c0189b12..296c74e5 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -18,7 +18,10 @@ half = { version = "2.3.1", features = ["num-traits"] } bindgen_cuda = "0.1.1" anyhow = { version = "1", features = ["backtrace"] } - [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } candle-nn = { path = "../candle-nn", features = ["cuda"] } + +[features] +default = [] +cudnn = ["candle/cudnn"] diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index e62f4c32..547e2045 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -33,6 +33,7 @@ criterion = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] +cudnn = ["candle/cudnn"] mkl = ["dep:intel-mkl-src", "candle/mkl"] metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 6589b4b1..fe0beefb 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -29,6 +29,7 @@ tracing = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] +cudnn = ["candle/cudnn", "candle-nn/cudnn"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] metal = ["candle/metal", "candle-nn/metal"]