Mistral gpu fix (#985)

* Add the mistral example.

* Use the two model files.

* Adjust the dtype.

* Tweak the weight paths.

* Remove the end of text token.

* Get the mistral model to generate some text.

* Fix when running on the gpu.

* More gpu fixes.
This commit is contained in:
Laurent Mazare
2023-09-28 17:38:13 +02:00
committed by GitHub
parent ada8851a23
commit 716ab2ccdc

View File

@ -72,7 +72,7 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
}
impl RotaryEmbedding {
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
@ -80,9 +80,9 @@ impl RotaryEmbedding {
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
@ -302,6 +302,7 @@ pub struct Model {
#[allow(unused)]
sliding_window: usize,
device: Device,
dtype: DType,
}
impl Model {
@ -309,7 +310,7 @@ impl Model {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
@ -325,6 +326,7 @@ impl Model {
lm_head,
sliding_window: cfg.sliding_window,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
@ -345,7 +347,8 @@ impl Model {
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {