Add the mistral example. (#984)

* Add the mistral example.

* Use the two model files.

* Adjust the dtype.

* Tweak the weight paths.

* Remove the end of text token.

* Get the mistral model to generate some text.
Laurent Mazare authored on 2023-09-28 17:19:18 +02:00, committed by GitHub
parent c05a348e36
commit ada8851a23
2 changed files with 240 additions and 11 deletions
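The "two model files" and "dtype" bullets refer to loading the Mistral-7B checkpoint, which ships as two safetensors shards. A minimal sketch of how such a split checkpoint can be memory-mapped into a single `VarBuilder` with candle is shown below; the shard file names and the dtype selection are illustrative placeholders, not necessarily what the example added in this commit does:

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // The 7B checkpoint is split across two safetensors shards; both are
    // memory-mapped into one VarBuilder. File names here are placeholders.
    let filenames = [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ];
    // Illustrative dtype choice: BF16 on CUDA, F32 on CPU.
    let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
    let _vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    // let model = Model::new(&config, _vb)?; // the model is then built from this builder
    Ok(())
}
```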

candle-transformers/src/models/mistral.rs (one of the two changed files; the new mistral example itself is not shown in this view)

@@ -1,7 +1,6 @@
-#![allow(unused)]
-use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear};
 /// Mistral LLM, https://github.com/mistralai/mistral-src
-use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
@@ -99,7 +98,7 @@ impl RotaryEmbedding {
         k: &Tensor,
         seqlen_offset: usize,
     ) -> Result<(Tensor, Tensor)> {
-        let (b_sz, seq_len, h, n_embd) = q.dims4()?;
+        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
         let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
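The `dims4` change above accounts for `q` and `k` arriving in (batch, num_heads, seq_len, head_dim) layout, so the sequence length is the third dimension rather than the second. A small, self-contained check of that destructuring and of the narrow on the rotary table, with made-up shapes:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // q in (batch, num_heads, seq_len, head_dim) layout, as the attention
    // layer passes it after transposing.
    let q = Tensor::zeros((1, 8, 5, 64), DType::F32, &dev)?;
    let (_b_sz, _h, seq_len, _head_dim) = q.dims4()?;
    assert_eq!(seq_len, 5); // the old destructuring would have picked up 8 (num_heads) here

    // The precomputed cos table is then narrowed to the current window.
    let cos_table = Tensor::zeros((4096, 32), DType::F32, &dev)?;
    let seqlen_offset = 3;
    let cos = cos_table.narrow(0, seqlen_offset, seq_len)?;
    assert_eq!(cos.dims(), &[seq_len, 32]);
    Ok(())
}
```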
@@ -240,7 +239,7 @@ impl Attention {
         let attn_weights = match attention_mask {
             None => attn_weights,
-            Some(mask) => (attn_weights + mask)?,
+            Some(mask) => attn_weights.broadcast_add(mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let attn_output = attn_weights.matmul(&value_states)?;
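Swapping `+` for `broadcast_add` matters because candle's arithmetic operators do not broadcast: the causal mask is built once per batch as (1, 1, seq_len, seq_len) while the attention scores are (batch, num_heads, seq_len, seq_len). A quick sketch of the difference, with shapes invented for the example:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Attention scores: (batch, num_heads, seq_len, seq_len).
    let attn_weights = Tensor::zeros((2, 4, 3, 3), DType::F32, &dev)?;
    // Causal mask shared across batch and heads: (1, 1, seq_len, seq_len).
    let mask = Tensor::zeros((1, 1, 3, 3), DType::F32, &dev)?;

    // Plain `+` requires identical shapes and fails here ...
    assert!((&attn_weights + &mask).is_err());
    // ... while broadcast_add expands the mask over the batch and head dims.
    let masked = attn_weights.broadcast_add(&mask)?;
    assert_eq!(masked.dims(), &[2, 4, 3, 3]);
    Ok(())
}
```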
@@ -290,7 +289,7 @@ impl DecoderLayer {
         let xs = (xs + residual)?;
         let residual = &xs;
         let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
-        Ok(xs)
+        residual + xs
     }
 }
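The one-line change above restores the second residual connection: the layer previously returned the MLP output alone and dropped the skip path around it. A toy version of the pre-norm block structure, with an arbitrary shape-preserving op standing in for the (norm + attention) and (norm + MLP) sub-blocks:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Hypothetical stand-in for a sub-block; any shape-preserving op will do.
fn sub_block(xs: &Tensor) -> Result<Tensor> {
    xs.affine(0.5, 0.1)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::ones((1, 3, 8), DType::F32, &dev)?;

    let residual = &xs;
    let xs = sub_block(&xs)?;
    let xs = (xs + residual)?; // first skip connection
    let residual = &xs;
    let xs = sub_block(&xs)?;
    let out = (residual + xs)?; // second skip connection, the one this commit restores
    assert_eq!(out.dims(), &[1, 3, 8]);
    Ok(())
}
```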
@@ -300,22 +299,24 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
+    #[allow(unused)]
     sliding_window: usize,
     device: Device,
 }
 
 impl Model {
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("model");
         let embed_tokens =
-            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
-        let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb.device())?);
+            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+        let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
-        let vb_l = vb.pp("layers");
+        let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
             let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
-        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
+        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
         let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
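The "tweak the weight paths" bullet corresponds to the `vb_m` changes above: in the Mistral checkpoint most tensors live under a `model.` prefix (`model.embed_tokens.weight`, `model.layers.N....`, `model.norm.weight`) while `lm_head.weight` sits at the top level. A small sketch of how `VarBuilder::pp` composes these prefixes, using an in-memory tensor map rather than the real checkpoint files; shapes and the toy map are assumptions for illustration:

```rust
use std::collections::HashMap;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Tiny in-memory "checkpoint": most weights live under `model.`,
    // but `lm_head` sits at the top level, as in the real safetensors files.
    let mut ts = HashMap::new();
    ts.insert(
        "model.embed_tokens.weight".to_string(),
        Tensor::zeros((16, 4), DType::F32, &dev)?,
    );
    ts.insert(
        "lm_head.weight".to_string(),
        Tensor::zeros((16, 4), DType::F32, &dev)?,
    );
    let vb = VarBuilder::from_tensors(ts, DType::F32, &dev);

    // `pp` pushes a path segment, mirroring the `vb_m = vb.pp("model")` tweak.
    let vb_m = vb.pp("model");
    let _embed = vb_m.pp("embed_tokens").get((16, 4), "weight")?; // model.embed_tokens.weight
    let _head = vb.pp("lm_head").get((16, 4), "weight")?; // lm_head.weight
    Ok(())
}
```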
@@ -359,6 +360,8 @@ impl Model {
         for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
         }
-        xs.apply(&self.norm)?.apply(&self.lm_head)
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
     }
 }
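With the final hunk, `forward` returns logits only for the last position, shape (batch, 1, vocab_size), instead of the whole sequence, which is what a generation loop needs. A hedged sketch of what a caller might do with that output, using a random tensor in place of real logits and greedy argmax in place of the example's sampler:

```rust
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b_sz, seq_len, vocab_size) = (1usize, 5usize, 32000usize);
    // Stand-in for the logits of a 5-token prompt.
    let logits = Tensor::rand(0f32, 1f32, (b_sz, seq_len, vocab_size), &dev)?;

    // Keep only the last position, as Model::forward now does internally.
    let last = logits.narrow(1, seq_len - 1, 1)?; // (1, 1, vocab_size)
    let last = last.squeeze(1)?.squeeze(0)?; // (vocab_size,)
    // Greedy decoding for the sketch; the real example samples from the distribution.
    let next_token = last.argmax(D::Minus1)?.to_scalar::<u32>()?;
    println!("next token id: {next_token}");
    Ok(())
}
```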