diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 6633ec50..0d5f3cb6 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"] -cudnn = ["candle/cudnn"] +cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] nccl = ["cuda", "cudarc/nccl", "dep:half"] diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index c0189b12..296c74e5 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -18,7 +18,10 @@ half = { version = "2.3.1", features = ["num-traits"] } bindgen_cuda = "0.1.1" anyhow = { version = "1", features = ["backtrace"] } - [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } candle-nn = { path = "../candle-nn", features = ["cuda"] } + +[features] +default = [] +cudnn = ["candle/cudnn"] diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index e62f4c32..547e2045 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -33,6 +33,7 @@ criterion = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] +cudnn = ["candle/cudnn"] mkl = ["dep:intel-mkl-src", "candle/mkl"] metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 6589b4b1..fe0beefb 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -29,6 +29,7 @@ tracing = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] +cudnn = ["candle/cudnn", "candle-nn/cudnn"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] metal = ["candle/metal", "candle-nn/metal"]