Support alternative dtypes for mamba (#2036)

* Allow different dtypes in mamba.

* Add a dtype flag.
Author: Laurent Mazare
Date: 2024-04-10 18:10:01 +02:00
Committed by: GitHub
Parent commit: a4d5a414e3
Commit: b81ecf712d
5 changed files with 24 additions and 11 deletions
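
The "Add a dtype flag" part of the change lives in the example binary, which is not one of the files shown in this excerpt. As a rough, hypothetical sketch of what such a flag typically looks like in a candle example (the argument name, default value, and error handling below are assumptions, not the actual change):

use candle::DType;
use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Dtype used for the weights and activations, e.g. "f32", "bf16" or "f16".
    #[arg(long, default_value = "f32")]
    dtype: String,
}

fn parse_dtype(s: &str) -> anyhow::Result<DType> {
    // Map the string flag onto a candle DType; unsupported values are rejected.
    match s {
        "f32" => Ok(DType::F32),
        "bf16" => Ok(DType::BF16),
        "f16" => Ok(DType::F16),
        _ => anyhow::bail!("unsupported dtype {s}"),
    }
}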


@@ -179,7 +179,9 @@ impl FalconRotaryEmbedding {
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?
+        .to_dtype(on_false.dtype())?
+        .broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }


@@ -1,4 +1,3 @@
-#![allow(unused)]
 /// A fast implementation of mamba for inference only.
 /// This is based on: https://github.com/LaurentMazare/mamba.rs
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
@@ -38,12 +37,12 @@ pub struct State {
 }

 impl State {
-    pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
+    pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result<Self> {
         let mut hs = Vec::with_capacity(cfg.n_layer);
         let mut prev_xs = Vec::with_capacity(cfg.n_layer);
         for _i in 0..cfg.n_layer {
-            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
-            let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
+            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?;
+            let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?;
             hs.push(h);
             prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
         }
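
With the extra `dtype` argument, callers decide the precision of the recurrent state up front instead of always getting F32 buffers. The snippet below only illustrates that pattern with made-up dimensions (`D_INNER`, `D_STATE` are hypothetical stand-ins for `cfg.d_inner()` and the crate's constant), not code from the PR:

use candle::{DType, Device, Result, Tensor};

// Hypothetical stand-ins for cfg.d_inner() and the D_STATE constant.
const D_INNER: usize = 1536;
const D_STATE: usize = 16;

fn new_layer_state(batch_size: usize, dtype: DType, device: &Device) -> Result<(Tensor, Tensor)> {
    // Zero-initialized hidden state and previous input, in the caller's dtype
    // (e.g. BF16 when the weights were loaded in BF16).
    let h = Tensor::zeros((batch_size, D_INNER, D_STATE), dtype, device)?;
    let x = Tensor::zeros((batch_size, D_INNER), dtype, device)?;
    Ok((h, x))
}

fn main() -> Result<()> {
    let (h, x) = new_layer_state(1, DType::BF16, &Device::Cpu)?;
    println!("h: {:?}, x: {:?}", h.dtype(), x.dtype());
    Ok(())
}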
@@ -128,8 +127,8 @@ impl MambaBlock {
         let delta = delta.apply(&self.dt_proj)?;
         // softplus
         let delta = (delta.exp()? + 1.)?.log()?;
-        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
-        let d = self.d.to_dtype(candle::DType::F32)?;
+        let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
+        let d = self.d.to_dtype(delta.dtype())?;
         // Selective scan part
         // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
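
Here the parameters `a_log` and `d` are cast to `delta.dtype()` rather than hard-coded F32, so the selective scan runs entirely in the activation dtype. A small hedged sketch of the same softplus-then-cast pattern in isolation (random tensors standing in for the real weights):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Pretend activations in BF16 and a parameter stored in F32.
    let delta = Tensor::randn(0f32, 1.0, (1, 8), &device)?.to_dtype(DType::BF16)?;
    let a_log = Tensor::randn(0f32, 1.0, (8, 16), &device)?;
    // softplus(delta) = log(1 + exp(delta)), computed in the activation dtype.
    let delta = (delta.exp()? + 1.)?.log()?;
    // Cast the parameter to whatever dtype the activations use before exp/neg.
    let a = a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
    println!("delta: {:?}, a: {:?}", delta.dtype(), a.dtype());
    Ok(())
}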
@@ -178,6 +177,7 @@ pub struct Model {
     layers: Vec<ResidualBlock>,
     norm_f: RmsNorm,
     lm_head: Linear,
+    dtype: DType,
 }

 impl Model {
@@ -196,6 +196,7 @@ impl Model {
             layers,
             norm_f,
             lm_head,
+            dtype: vb.dtype(),
         })
     }
@@ -208,4 +209,8 @@ impl Model {
         state.pos += 1;
         xs.apply(&self.norm_f)?.apply(&self.lm_head)
     }
+
+    pub fn dtype(&self) -> DType {
+        self.dtype
+    }
 }
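
Since the model now records `vb.dtype()` and exposes it through `dtype()`, callers can build a matching `State` without tracking the precision separately. A hedged sketch of how the pieces might be wired together (the `backbone` prefix, the helper name, and the loading path are assumptions, not necessarily the example's actual code):

use candle::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::mamba::{Config, Model, State};

fn load_model(
    config: &Config,
    weights: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<(Model, State)> {
    // Load the safetensors weights in the requested dtype.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(weights, dtype, device)? };
    let model = Model::new(config, vb.pp("backbone"))?;
    // The state buffers follow the model's dtype via the new getter.
    let state = State::new(1, config, model.dtype(), device)?;
    Ok((model, state))
}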