Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.
* More flash attn.
* Set up the flash attn parameters.
* Get things to compile locally.
* Move the flash attention files in a different directory.
* Build the static C library with nvcc.
* Add more flash attention.
* Update the build part.
* Better caching.
* Exclude flash attention from the default workspace.
* Put flash-attn behind a feature gate.
* Get the flash attn kernel to run.
* Move the flags to a more appropriate place.
* Enable flash attention in llama.
* Use flash attention in llama.
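The bullets above mention building the flash-attn CUDA sources into a static C library with nvcc and caching the result. The build script below is not the actual candle-flash-attn build script, only a minimal sketch of that approach; the kernel file name is a placeholder, and the real script compiles many source files, passes GPU architecture flags, and caches the compiled objects.

```rust
// Hypothetical, simplified build.rs sketch: compile a flash-attn CUDA
// source into a static library with nvcc and tell cargo to link it.
use std::path::PathBuf;
use std::process::Command;

fn main() {
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
    let lib = out_dir.join("libflashattention.a");

    // Rebuild only when the kernel source changes (the "better caching" part).
    println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim64_fp16_sm80.cu");

    // Invoke nvcc to produce the static library.
    let status = Command::new("nvcc")
        .args(["-O3", "--lib", "-o"])
        .arg(&lib)
        .arg("kernels/flash_fwd_hdim64_fp16_sm80.cu")
        .status()
        .expect("failed to spawn nvcc");
    assert!(status.success(), "nvcc failed");

    // Link the resulting static library (and the CUDA runtime) into the crate.
    println!("cargo:rustc-link-search=native={}", out_dir.display());
    println!("cargo:rustc-link-lib=static=flashattention");
    println!("cargo:rustc-link-lib=dylib=cudart");
}
```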
```diff
@@ -14,6 +14,7 @@ readme = "README.md"
 candle = { path = "../candle-core" }
 candle-nn = { path = "../candle-nn" }
 candle-transformers = { path = "../candle-transformers" }
+candle-flash-attn = { path = "../candle-flash-attn", optional = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 num-traits = { workspace = true }
@@ -37,4 +38,5 @@ anyhow = { workspace = true }
 [features]
 default = []
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
```