Mirror of https://github.com/huggingface/candle.git (synced 2025-06-21 04:10:46 +00:00)
Mistral gpu fix (#985)
* Add the mistral example.
* Use the two model files.
* Adjust the dtype.
* Tweak the weight paths.
* Remove the end of text token.
* Get the mistral model to generate some text.
* Fix when running on the gpu.
* More gpu fixes.
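At its core this is a dtype-plumbing fix: the rotary tables and the causal attention mask were materialized in F32, while the weights are typically loaded in half precision on the GPU, and candle's binary ops do not auto-promote mismatched dtypes. The sketch below is illustrative only (the shapes and the `main` wrapper are made up); it shows the kind of mismatch that fails and the `to_dtype` conversion this commit threads through `RotaryEmbedding::new` and the mask construction.

```rust
// Minimal sketch of the failure mode this commit addresses, not code from the patch.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // The same mismatch surfaces on CUDA when the model weights are loaded in F16/BF16.
    let dev = Device::Cpu;
    let weights = Tensor::zeros((4, 4), DType::F16, &dev)?;
    let cos = Tensor::zeros((4, 4), DType::F32, &dev)?;

    // Mixing F16 and F32 operands errors out instead of auto-promoting.
    println!("f16 * f32 -> {:?}", weights.broadcast_mul(&cos).map(|t| t.dtype()));

    // Casting the auxiliary tensor to the model dtype, as the patch does, fixes it.
    let cos = cos.to_dtype(weights.dtype())?;
    println!("f16 * f16 -> {:?}", weights.broadcast_mul(&cos)?.dtype());
    Ok(())
}
```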
@@ -72,7 +72,7 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 }
 
 impl RotaryEmbedding {
-    fn new(cfg: &Config, dev: &Device) -> Result<Self> {
+    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.hidden_size / cfg.num_attention_heads;
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
@@ -80,9 +80,9 @@ impl RotaryEmbedding {
             .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
         let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(DType::F32)?
+            .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
         let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
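Taken together, the two hunks above make `RotaryEmbedding::new` accept the model dtype and cast its precomputed tensors to it rather than leaving them in F32. Below is a hedged, standalone condensation of that precomputation; the `rotary_tables` helper name, the even-index stride, and the concrete `dim`/`max_seq_len` values are illustrative assumptions, not lifted verbatim from the file.

```rust
// Standalone sketch of dtype-aware rotary table precomputation; the names and the
// step_by(2) stride are assumptions, only the to_dtype(dtype) casts mirror the patch.
use candle_core::{DType, Device, Result, Tensor, D};

fn rotary_tables(dtype: DType, dim: usize, max_seq_len: usize, dev: &Device) -> Result<(Tensor, Tensor)> {
    let inv_freq: Vec<_> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
        .collect();
    let inv_freq_len = inv_freq.len();
    // Cast the frequency table to the model dtype so every later op sees a single dtype.
    let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
    // Position indices, also cast to the model dtype (hard-coded to F32 before the fix).
    let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
        .to_dtype(dtype)?
        .reshape((max_seq_len, 1))?;
    let freqs = t.matmul(&inv_freq)?;
    let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
    Ok((freqs.sin()?, freqs.cos()?))
}

fn main() -> Result<()> {
    let (sin, cos) = rotary_tables(DType::F32, 64, 128, &Device::Cpu)?;
    println!("sin {:?} / cos {:?}", sin.shape(), cos.shape());
    Ok(())
}
```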
@@ -302,6 +302,7 @@ pub struct Model {
     #[allow(unused)]
     sliding_window: usize,
     device: Device,
+    dtype: DType,
 }
 
 impl Model {
@@ -309,7 +310,7 @@ impl Model {
         let vb_m = vb.pp("model");
         let embed_tokens =
             candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
-        let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
+        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
@@ -325,6 +326,7 @@ impl Model {
             lm_head,
             sliding_window: cfg.sliding_window,
             device: vb.device().clone(),
+            dtype: vb.dtype(),
         })
     }
 
@@ -345,7 +347,8 @@ impl Model {
         } else {
             mask
         };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))
+        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+            .to_dtype(self.dtype)
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
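The final hunk is the mask half of the fix: the causal mask is assembled from f32 values, so it is now cast to `self.dtype` right after the `expand`, letting it combine with half-precision attention scores. A hedged, self-contained sketch of the same pattern (without the sliding-window/offset handling) follows; the `causal_mask` helper and its layout are illustrative, only the trailing `expand(...)?.to_dtype(...)` mirrors the patch.

```rust
// Illustrative causal-mask construction; only the final expand + to_dtype step
// reflects what this commit changes in the model code.
use candle_core::{DType, Device, Result, Tensor};

fn causal_mask(b_size: usize, tgt_len: usize, dtype: DType, dev: &Device) -> Result<Tensor> {
    // 0.0 where attention is allowed, -inf where a position must not attend.
    let mask: Vec<f32> = (0..tgt_len)
        .flat_map(|i| (0..tgt_len).map(move |j| if j <= i { 0.0 } else { f32::NEG_INFINITY }))
        .collect();
    let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), dev)?;
    // Broadcast to (batch, 1, tgt, src) and cast to the model dtype before use,
    // mirroring the `.to_dtype(self.dtype)` added by this commit.
    mask.expand((b_size, 1, tgt_len, tgt_len))?.to_dtype(dtype)
}

fn main() -> Result<()> {
    println!("{}", causal_mask(1, 4, DType::F32, &Device::Cpu)?);
    Ok(())
}
```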