Support alternative dtypes for mamba (#2036)

* Allow different dtypes in mamba.

* Add a dtype flag.
Laurent Mazare
2024-04-10 18:10:01 +02:00
committed by GitHub
parent a4d5a414e3
commit b81ecf712d
5 changed files with 24 additions and 11 deletions
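
The flag defaults to f32, so existing invocations are unchanged. A typical half-precision run of the example might look like the line below; the --prompt value is illustrative and the flags other than --dtype follow the usual candle-examples conventions rather than anything in this diff:

    cargo run --example mamba --release -- --prompt "Mamba is the" --dtype bf16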

candle-examples/examples/mamba/main.rs

@@ -54,6 +54,7 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         self.tokenizer.clear();
+        let dtype = self.model.dtype();
         let mut tokens = self
             .tokenizer
             .tokenizer()
@@ -66,7 +67,7 @@ impl TextGeneration {
             Some(token) => token,
             None => anyhow::bail!("cannot find the </s> token"),
         };
-        let mut state = State::new(1, &self.config, &self.device)?;
+        let mut state = State::new(1, &self.config, dtype, &self.device)?;
         let mut next_logits = None;
         for &t in tokens.iter() {
             let input = Tensor::new(&[t], &self.device)?;
@@ -84,7 +85,7 @@ impl TextGeneration {
             Some(logits) => logits,
             None => anyhow::bail!("cannot work on an empty prompt"),
         };
-        let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+        let logits = logits.squeeze(0)?.to_dtype(dtype)?;
         let logits = if self.repeat_penalty == 1. {
             logits
         } else {
@@ -210,6 +211,9 @@ struct Args {
     #[arg(long)]
     config_file: Option<String>,
 
+    #[arg(long, default_value = "f32")]
+    dtype: String,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@@ -220,6 +224,7 @@ struct Args {
 }
 
 fn main() -> Result<()> {
+    use std::str::FromStr;
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -279,7 +284,8 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+    let dtype = DType::from_str(&args.dtype)?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
     let model = Model::new(&config, vb.pp("backbone"))?;
     println!("loaded the model in {:?}", start.elapsed());
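
For context, the string-to-dtype step above works because candle::DType implements FromStr; "f32" (the default here), "f16", and "bf16" are the names relevant to this change. A minimal sketch of that parsing step, under those assumptions:

    use std::str::FromStr;
    use candle::DType;

    fn parse_dtype(name: &str) -> anyhow::Result<DType> {
        // Unknown names produce an Err, which main() surfaces via `?`.
        Ok(DType::from_str(name)?)
    }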

(Binary image file changed, 175 KiB; contents not shown.)

candle-transformers/src/models/falcon.rs

@@ -179,7 +179,9 @@ impl FalconRotaryEmbedding {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?
+        .to_dtype(on_false.dtype())?
+        .broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
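
The added cast matters once a model runs in half precision: Tensor::new on an f32 scalar always yields an F32 tensor, while where_cond expects on_true and on_false to share a dtype. A minimal sketch of the path being fixed; shapes and values are illustrative, not taken from the falcon code:

    use candle::{DType, Device, Result, Tensor};

    fn masked_fill_demo() -> Result<Tensor> {
        let dev = Device::Cpu;
        let on_false = Tensor::zeros((2, 2), DType::BF16, &dev)?;
        let mask = Tensor::ones((2, 2), DType::U8, &dev)?;
        let on_true = Tensor::new(f32::NEG_INFINITY, &dev)?
            // Without this cast, where_cond sees F32 vs BF16 and errors out.
            .to_dtype(on_false.dtype())?
            .broadcast_as(on_false.shape().dims())?;
        mask.where_cond(&on_true, &on_false)
    }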

candle-transformers/src/models/mamba.rs

@@ -1,4 +1,3 @@
-#![allow(unused)]
 /// A fast implementation of mamba for inference only.
 /// This is based on: https://github.com/LaurentMazare/mamba.rs
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
@@ -38,12 +37,12 @@ pub struct State {
 }
 
 impl State {
-    pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
+    pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result<Self> {
         let mut hs = Vec::with_capacity(cfg.n_layer);
         let mut prev_xs = Vec::with_capacity(cfg.n_layer);
         for _i in 0..cfg.n_layer {
-            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
-            let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
+            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?;
+            let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?;
             hs.push(h);
             prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
         }
@@ -128,8 +127,8 @@ impl MambaBlock {
         let delta = delta.apply(&self.dt_proj)?;
         // softplus
         let delta = (delta.exp()? + 1.)?.log()?;
-        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
-        let d = self.d.to_dtype(candle::DType::F32)?;
+        let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
+        let d = self.d.to_dtype(delta.dtype())?;
 
         // Selective scan part
         // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
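
Two details worth noting in this hunk. First, a_log and d are now cast to delta.dtype() instead of a hard-coded F32, which keeps the selective scan entirely in the dtype the weights were loaded in. Second, the softplus stays in its plain ln(1 + e^x) form; a scalar sketch of that identity (the plain form can overflow for large inputs in low precision, which this commit does not change):

    // softplus(x) = ln(1 + exp(x)), applied elementwise by the tensor code above.
    fn softplus(x: f64) -> f64 {
        (x.exp() + 1.0).ln()
    }
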
@@ -178,6 +177,7 @@ pub struct Model {
     layers: Vec<ResidualBlock>,
     norm_f: RmsNorm,
     lm_head: Linear,
+    dtype: DType,
 }
 
 impl Model {
@@ -196,6 +196,7 @@ impl Model {
             layers,
             norm_f,
             lm_head,
+            dtype: vb.dtype(),
         })
     }
@@ -208,4 +209,8 @@ impl Model {
         state.pos += 1;
         xs.apply(&self.norm_f)?.apply(&self.lm_head)
     }
+
+    pub fn dtype(&self) -> DType {
+        self.dtype
+    }
 }
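
The new dtype() accessor is what lets the example allocate a matching State without re-parsing the flag. A sketch of the call-site pattern this diff establishes; the names mirror the example code:

    use candle::{Device, Result};
    use candle_transformers::models::mamba::{Config, Model, State};

    // Allocate a batch-1 recurrent state in the same dtype as the model weights.
    fn init_state(model: &Model, config: &Config, device: &Device) -> Result<State> {
        State::new(1, config, model.dtype(), device)
    }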

candle-transformers/src/utils.rs

@@ -2,7 +2,7 @@ use candle::{Result, Tensor};
 
 pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
     let device = logits.device();
-    let mut logits = logits.to_vec1::<f32>()?;
+    let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
     let mut already_seen = std::collections::HashSet::new();
     for token_id in context {
         if already_seen.contains(token_id) {
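
This last change keeps apply_repeat_penalty usable with half-precision models: to_vec1::<f32>() requires the tensor to actually be F32, so bf16/f16 logits would fail without the upcast. A small usage sketch with illustrative values:

    use candle::{DType, Device, Result, Tensor};
    use candle_transformers::utils::apply_repeat_penalty;

    fn penalize_demo() -> Result<Tensor> {
        let dev = Device::Cpu;
        // BF16 logits are upcast to F32 internally before being penalized.
        let logits = Tensor::new(&[1.0f32, 2.0, 3.0], &dev)?.to_dtype(DType::BF16)?;
        apply_repeat_penalty(&logits, 1.1, &[2u32])
    }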