Add some flash attn test (#253)

* Add some flash-attn test.

* Add the cpu test.

* Fail when the head is not a multiple of 8.

* Polish the flash attention test.
This commit is contained in:
Laurent Mazare
2023-07-26 20:56:00 +01:00
committed by GitHub
parent ded197497c
commit 4f92420132
5 changed files with 125 additions and 14 deletions

View File

@@ -18,3 +18,6 @@ half = { version = "2.3.1", features = ["num-traits"] }
anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }