Support for attention bias in gemma + refactor things a bit. (#1744)

* Support for attention bias in gemma + refactor things a bit.

* Fix the cuda tests.
This commit is contained in:
Laurent Mazare
2024-02-22 09:35:28 +01:00
committed by GitHub
parent 8013b50829
commit c753f72c85
8 changed files with 62 additions and 88 deletions

View File

@ -47,6 +47,12 @@ impl Linear {
}
}
/// Creates a [`Linear`] wrapper around the layer returned by
/// `candle_nn::linear_b`, paired with a TRACE-level tracing span named
/// `"linear"`.
///
/// `d1`, `d2`, `b` and `vb` are forwarded unchanged to `candle_nn::linear_b`.
/// NOTE(review): presumably `d1`/`d2` are the input/output dimensions and `b`
/// toggles the bias term — confirm against the `candle_nn` documentation.
///
/// # Errors
///
/// Propagates any error returned by `candle_nn::linear_b`.
pub fn linear_b(d1: usize, d2: usize, b: bool, vb: VarBuilder) -> Result<Linear> {
    // Field initializers run in the order written, so the span is only
    // created once the inner layer has been built successfully.
    let layer = Linear {
        inner: candle_nn::linear_b(d1, d2, b, vb)?,
        span: tracing::span!(tracing::Level::TRACE, "linear"),
    };
    Ok(layer)
}
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
let inner = candle_nn::linear(d1, d2, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "linear");