Commit Graph

3 Commits

Author SHA1 Message Date
d9f9c859af Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files in a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
2023-07-26 07:48:10 +01:00
e449ce53a2 Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op.

* Get the rms example to work.

* Get around rustfmt failing in the CI.

* Fix the rms computation.
2023-07-23 11:31:17 +01:00
b8a10425ad Kernel build example (#224)
* Build example kernels.

* Add some sample custom kernel.

* Get the example kernel to compile.

* Add some cuda code.

* More cuda custom op.

* More cuda custom ops.
2023-07-23 07:15:37 +01:00