Commit Graph

6 Commits

Author SHA1 Message Date
1235aa2536 Use bail rather than wrapping a string where possible. (#249)
* Use bail rather than wrapping a string where possible.

* Revert the cuda default bit.
2023-07-26 15:42:46 +01:00
f052ba76cb Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function in the proper crate.

* Causality tweak.
2023-07-26 15:11:45 +01:00
2ce5f12513 Again set a few extra params in flash-attn. (#245)
* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling.

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
2023-07-26 14:16:37 +01:00
fa2b64d678 Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Setup the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
2023-07-26 10:13:40 +01:00
471855e2ee Specific cache dir for the flash attn build artifacts. (#242) 2023-07-26 08:04:02 +01:00
d9f9c859af Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files in a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
2023-07-26 07:48:10 +01:00