From 716ab2ccdcb07aab26c41a98a839c31ac9760ca6 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 28 Sep 2023 17:38:13 +0200
Subject: [PATCH] Mistral gpu fix (#985)

* Add the mistral example.

* Use the two model files.

* Adjust the dtype.

* Tweak the weight paths.

* Remove the end of text token.

* Get the mistral model to generate some text.

* Fix when running on the gpu.

* More gpu fixes.
---
 candle-transformers/src/models/mistral.rs | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 7db83ff1..33569bd8 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -72,7 +72,7 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 }
 
 impl RotaryEmbedding {
-    fn new(cfg: &Config, dev: &Device) -> Result<Self> {
+    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.hidden_size / cfg.num_attention_heads;
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
@@ -80,9 +80,9 @@ impl RotaryEmbedding {
             .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
         let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(DType::F32)?
+            .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
         let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
@@ -302,6 +302,7 @@ pub struct Model {
     #[allow(unused)]
     sliding_window: usize,
     device: Device,
+    dtype: DType,
 }
 
 impl Model {
@@ -309,7 +310,7 @@ impl Model {
         let vb_m = vb.pp("model");
         let embed_tokens =
             candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
-        let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
+        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
@@ -325,6 +326,7 @@ impl Model {
             lm_head,
             sliding_window: cfg.sliding_window,
             device: vb.device().clone(),
+            dtype: vb.dtype(),
         })
     }
 
@@ -345,7 +347,8 @@ impl Model {
         } else {
             mask
         };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))
+        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+            .to_dtype(self.dtype)
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
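
Note (not part of the patch): the change threads the VarBuilder's dtype through to tensors that were previously hard-coded to F32, so that with F16/BF16 weights on the GPU the rotary-embedding tables and the attention mask match the activation dtype instead of triggering a dtype-mismatch error. Below is a minimal standalone sketch of that pattern, assuming only the candle-core crate and its public Tensor::from_slice / to_dtype APIs; the causal_mask helper and the concrete dtype/device choices are illustrative, not code from this repository.

use candle_core::{DType, Device, Result, Tensor};

// Build a causal mask in F32, then cast it to the model dtype before it is
// combined with attention scores, mirroring the `.to_dtype(self.dtype)` call
// this patch adds to prepare_decoder_attention_mask.
fn causal_mask(tgt_len: usize, dtype: DType, dev: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..tgt_len)
        .flat_map(|i| (0..tgt_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
        .collect();
    Tensor::from_slice(&mask, (tgt_len, tgt_len), dev)?.to_dtype(dtype)
}

fn main() -> Result<()> {
    // On a CUDA build this would be Device::new_cuda(0)?, with the model
    // loaded in DType::BF16 or DType::F16; CPU is used here so the sketch
    // runs anywhere.
    let dev = Device::Cpu;
    let mask = causal_mask(4, DType::BF16, &dev)?;
    println!("{mask}");
    Ok(())
}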