mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Olmo 2 model (#2954)
* OLMo 2 model * Update olmo-2 to example * Clippy fix. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -3,7 +3,7 @@
|
|||||||
OLMo is a series of Open Language Models designed to enable the science of language models.
|
OLMo is a series of Open Language Models designed to enable the science of language models.
|
||||||
|
|
||||||
- **Project Page:** https://allenai.org/olmo
|
- **Project Page:** https://allenai.org/olmo
|
||||||
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
|
- **Papers:** [OLMo](https://arxiv.org/abs/2402.00838) [OLMo 2](https://arxiv.org/abs/2501.00656)
|
||||||
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
|
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
|
||||||
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
|
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
|
||||||
<!-- - **Press release:** TODO -->
|
<!-- - **Press release:** TODO -->
|
||||||
|
@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
|
|||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
use candle_transformers::models::olmo::{Config, Model as OLMo};
|
use candle_transformers::models::olmo::{Config, Model as OLMo};
|
||||||
|
use candle_transformers::models::olmo2::{Config as Config2, Model as OLMo2};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
|
|||||||
|
|
||||||
enum Model {
|
enum Model {
|
||||||
OLMo(OLMo),
|
OLMo(OLMo),
|
||||||
|
OLMo2(OLMo2),
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
@ -82,6 +84,7 @@ impl TextGeneration {
|
|||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
let logits = match &mut self.model {
|
let logits = match &mut self.model {
|
||||||
Model::OLMo(m) => m.forward(&input, start_pos)?,
|
Model::OLMo(m) => m.forward(&input, start_pos)?,
|
||||||
|
Model::OLMo2(m) => m.forward(&input, start_pos)?,
|
||||||
};
|
};
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
let logits = if self.repeat_penalty == 1. {
|
let logits = if self.repeat_penalty == 1. {
|
||||||
@ -129,6 +132,8 @@ enum Which {
|
|||||||
W7bTwin2T,
|
W7bTwin2T,
|
||||||
#[value(name = "1.7-7b")]
|
#[value(name = "1.7-7b")]
|
||||||
V1_7W7b,
|
V1_7W7b,
|
||||||
|
#[value(name = "2-1b")]
|
||||||
|
V2W1b,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -220,6 +225,7 @@ fn main() -> Result<()> {
|
|||||||
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
|
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
|
||||||
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
|
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
|
||||||
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
|
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
|
||||||
|
Which::V2W1b => "allenai/OLMo-2-0425-1B-Instruct".to_string(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -238,33 +244,36 @@ fn main() -> Result<()> {
|
|||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
Which::W1b => {
|
Which::W1b | Which::V2W1b => {
|
||||||
vec![repo.get("model.safetensors")?]
|
vec![repo.get("model.safetensors")?]
|
||||||
}
|
}
|
||||||
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let config_filename = repo.get("config.json")?;
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = {
|
|
||||||
let config_filename = repo.get("config.json")?;
|
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
|
||||||
config
|
|
||||||
};
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let model = {
|
let dtype = if device.is_cuda() {
|
||||||
let dtype = if device.is_cuda() {
|
DType::BF16
|
||||||
DType::BF16
|
} else {
|
||||||
} else {
|
DType::F32
|
||||||
DType::F32
|
};
|
||||||
};
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let model = match args.model {
|
||||||
let model = OLMo::new(&config, vb)?;
|
Which::W1b | Which::W7b | Which::W7bTwin2T | Which::V1_7W7b => {
|
||||||
Model::OLMo(model)
|
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
|
let model = OLMo::new(&config, vb)?;
|
||||||
|
Model::OLMo(model)
|
||||||
|
}
|
||||||
|
Which::V2W1b => {
|
||||||
|
let config: Config2 = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
|
let model = OLMo2::new(&config, vb)?;
|
||||||
|
Model::OLMo2(model)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
@ -70,6 +70,7 @@ pub mod moondream;
|
|||||||
pub mod mpt;
|
pub mod mpt;
|
||||||
pub mod nvembed_v2;
|
pub mod nvembed_v2;
|
||||||
pub mod olmo;
|
pub mod olmo;
|
||||||
|
pub mod olmo2;
|
||||||
pub mod openclip;
|
pub mod openclip;
|
||||||
pub mod paligemma;
|
pub mod paligemma;
|
||||||
pub mod parler_tts;
|
pub mod parler_tts;
|
||||||
|
348
candle-transformers/src/models/olmo2.rs
Normal file
348
candle-transformers/src/models/olmo2.rs
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
//! OLMo 2 (Open Language Model) implementation
|
||||||
|
//!
|
||||||
|
//! See OLMo 2 model details at:
|
||||||
|
//! - [Hugging Face Collection](https://huggingface.co/collections/allenai/olmo-2-674117b93ab84e98afc72edc)
|
||||||
|
//! - [OLMo 2 Paper](https://arxiv.org/abs/2501.00656)
|
||||||
|
//!
|
||||||
|
//!
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{linear_b, linear_no_bias, rms_norm, Activation, Linear, RmsNorm, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub attention_bias: bool,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
pub hidden_act: candle_nn::Activation,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub rope_theta: f64,
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
pub clip_qkv: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?,
|
||||||
|
cos: freqs.cos()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MLP {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MLP {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_sz = cfg.hidden_size;
|
||||||
|
let intermediate_sz = cfg.intermediate_size;
|
||||||
|
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
||||||
|
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
||||||
|
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MLP {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = xs.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
q_norm: RmsNorm,
|
||||||
|
k_norm: RmsNorm,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_sz = cfg.hidden_size;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
let head_dim = hidden_sz / num_heads;
|
||||||
|
let b = cfg.attention_bias;
|
||||||
|
let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
|
||||||
|
let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
|
||||||
|
let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
|
||||||
|
let q_norm = rms_norm(hidden_sz, cfg.rms_norm_eps, vb.pp("q_norm"))?;
|
||||||
|
let k_norm = rms_norm(num_kv_heads * head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
hidden_size: hidden_sz,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b_sz, q_len, _) = xs.dims3()?;
|
||||||
|
|
||||||
|
let query_states = self.q_proj.forward(xs)?;
|
||||||
|
let key_states = self.k_proj.forward(xs)?;
|
||||||
|
let value_states = self.v_proj.forward(xs)?;
|
||||||
|
|
||||||
|
let query_states = self.q_norm.forward(&query_states)?;
|
||||||
|
let key_states = self.k_norm.forward(&key_states)?;
|
||||||
|
|
||||||
|
let query_states = query_states
|
||||||
|
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let key_states = key_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let value_states = value_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let (query_states, key_states) =
|
||||||
|
self.rotary_emb
|
||||||
|
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||||
|
|
||||||
|
let (key_states, value_states) = match &self.kv_cache {
|
||||||
|
None => (key_states, value_states),
|
||||||
|
Some((prev_k, prev_v)) => {
|
||||||
|
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||||
|
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||||
|
|
||||||
|
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let value_states =
|
||||||
|
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
let attn_output = {
|
||||||
|
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||||
|
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||||
|
|
||||||
|
let attn_weights = match attention_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||||
|
};
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
attn_weights.matmul(&value_states)?
|
||||||
|
};
|
||||||
|
attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_sz, q_len, self.hidden_size))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DecoderLayer {
|
||||||
|
self_attn: Attention,
|
||||||
|
mlp: MLP,
|
||||||
|
post_attention_layernorm: RmsNorm,
|
||||||
|
post_feedforward_layernorm: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderLayer {
|
||||||
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||||
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
let post_feedforward_layernorm = rms_norm(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_feedforward_layernorm"),
|
||||||
|
)?;
|
||||||
|
let post_attention_layernorm = rms_norm(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_attention_layernorm"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
post_attention_layernorm,
|
||||||
|
post_feedforward_layernorm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self.self_attn.forward(xs, attention_mask, seqlen_offset)?;
|
||||||
|
let xs = self.post_attention_layernorm.forward(&xs)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = self.mlp.forward(&xs)?;
|
||||||
|
let xs = self.post_feedforward_layernorm.forward(&xs)?;
|
||||||
|
residual + xs
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embed_tokens: candle_nn::Embedding,
|
||||||
|
layers: Vec<DecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
lm_head: Linear,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_m = vb.pp("model");
|
||||||
|
let embed_tokens =
|
||||||
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||||
|
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_l = vb_m.pp("layers");
|
||||||
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
|
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::new(embed_tokens.embeddings().clone(), None)
|
||||||
|
} else {
|
||||||
|
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
lm_head,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_decoder_attention_mask(
|
||||||
|
&self,
|
||||||
|
b_size: usize,
|
||||||
|
tgt_len: usize,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
// Sliding window mask?
|
||||||
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
|
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||||
|
let mask = if seqlen_offset > 0 {
|
||||||
|
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||||
|
.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
let (b_size, seq_len) = input_ids.dims2()?;
|
||||||
|
let attention_mask = if seq_len <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||||
|
Some(mask)
|
||||||
|
};
|
||||||
|
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||||
|
}
|
||||||
|
xs.narrow(1, seq_len - 1, 1)?
|
||||||
|
.apply(&self.norm)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
layer.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user