Fix the fast bf16 gemm cublas kernels. (#2274)

* Use flash-attn in gemma.

* Fix for the fast bf16 cublas gemm.

* Fix some clippy lints.

* Fix another lint.

* Proper clippy fix.
This commit is contained in:
Laurent Mazare
2024-06-18 23:46:58 +02:00
committed by GitHub
parent 2b10aaa05d
commit 36cf54525d
5 changed files with 25 additions and 14 deletions

View File

@@ -54,8 +54,7 @@ impl ModuleT for Vgg<'_> {
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
let layers = convs
.iter()
.enumerate()
.map(|(_, &(in_c, out_c, name))| {
.map(|&(in_c, out_c, name)| {
candle_nn::conv2d(
in_c,
out_c,