mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Mistral gpu fix (#985)
* Add the mistral example.
* Use the two model files.
* Adjust the dtype.
* Tweak the weight paths.
* Remove the end of text token.
* Get the mistral model to generate some text.
* Fix when running on the gpu.
* More gpu fixes.
This commit is contained in:
@ -72,7 +72,7 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
|
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||||
let max_seq_len = cfg.max_position_embeddings;
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
let inv_freq: Vec<_> = (0..dim)
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
@ -80,9 +80,9 @@ impl RotaryEmbedding {
|
|||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let inv_freq_len = inv_freq.len();
|
let inv_freq_len = inv_freq.len();
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(dtype)?
|
||||||
.reshape((max_seq_len, 1))?;
|
.reshape((max_seq_len, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||||
@ -302,6 +302,7 @@ pub struct Model {
|
|||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
sliding_window: usize,
|
sliding_window: usize,
|
||||||
device: Device,
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
@ -309,7 +310,7 @@ impl Model {
|
|||||||
let vb_m = vb.pp("model");
|
let vb_m = vb.pp("model");
|
||||||
let embed_tokens =
|
let embed_tokens =
|
||||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||||
let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
|
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
let vb_l = vb_m.pp("layers");
|
let vb_l = vb_m.pp("layers");
|
||||||
for layer_idx in 0..cfg.num_hidden_layers {
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
@ -325,6 +326,7 @@ impl Model {
|
|||||||
lm_head,
|
lm_head,
|
||||||
sliding_window: cfg.sliding_window,
|
sliding_window: cfg.sliding_window,
|
||||||
device: vb.device().clone(),
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -345,7 +347,8 @@ impl Model {
|
|||||||
} else {
|
} else {
|
||||||
mask
|
mask
|
||||||
};
|
};
|
||||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))
|
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||||
|
.to_dtype(self.dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
Reference in New Issue
Block a user