mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Qwen MoE model. (#1960)
* Qwen MoE model. * Add the MoE model to the example. * Fix the scaling. * Readme updates. * Readme tweaks.
This commit is contained in:
@ -125,6 +125,8 @@ We also provide a some command line based examples using state of the art models
|
||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
generate captions for an image.
|
||||
- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
|
||||
model.
|
||||
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
|
||||
dedicated submodels for hand-writing and printed recognition.
|
||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||
@ -206,7 +208,7 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Replit-code-v1.5-3B.
|
||||
- Bert.
|
||||
- Yi-6B and Yi-34B.
|
||||
- Qwen1.5.
|
||||
- Qwen1.5, Qwen1.5 MoE.
|
||||
- RWKV v5 and v6.
|
||||
- Quantized LLMs.
|
||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
||||
|
27
candle-examples/examples/qwen/README.md
Normal file
27
candle-examples/examples/qwen/README.md
Normal file
@ -0,0 +1,27 @@
|
||||
# candle-qwen: large language model series from Alibaba Cloud
|
||||
|
||||
Qwen 1.5 is a series of large language models that provide strong performances
|
||||
on English and Chinese.
|
||||
|
||||
- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
|
||||
- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
|
||||
- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
|
||||
mixture-of-experts (MoE) variant.
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ cargo run --example qwen --release -- --prompt "Hello there "
|
||||
```
|
||||
|
||||
Various model sizes are available via the `--model` argument, including the MoE
|
||||
variant.
|
||||
|
||||
```bash
|
||||
$ cargo run --example qwen --release -- --prompt "Hello there " --model moe-a2.7b --prompt 'def print_prime(n: int): '
|
||||
def print_prime(n: int): # n is the number of primes to be printed
|
||||
for i in range(2, n + 1):
|
||||
if all(i % j != 0 for j in range(2, i)):
|
||||
print(i)
|
||||
```
|
||||
|
@ -7,7 +7,8 @@ extern crate accelerate_src;
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::qwen2::{Config, Model};
|
||||
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
|
||||
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
Base(ModelBase),
|
||||
Moe(ModelMoe),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Moe(ref mut m) => m.forward(xs, s),
|
||||
Self::Base(ref mut m) => m.forward(xs, s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
@ -127,6 +142,8 @@ enum WhichModel {
|
||||
W14b,
|
||||
#[value(name = "72b")]
|
||||
W72b,
|
||||
#[value(name = "moe-a2.7b")]
|
||||
MoeA27b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -224,6 +241,7 @@ fn main() -> Result<()> {
|
||||
WhichModel::W7b => "7B",
|
||||
WhichModel::W14b => "14B",
|
||||
WhichModel::W72b => "72B",
|
||||
WhichModel::MoeA27b => "MoE-A2.7B",
|
||||
};
|
||||
format!("Qwen/Qwen1.5-{size}")
|
||||
}
|
||||
@ -244,7 +262,11 @@ fn main() -> Result<()> {
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.model {
|
||||
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
|
||||
WhichModel::W4b
|
||||
| WhichModel::W7b
|
||||
| WhichModel::W14b
|
||||
| WhichModel::W72b
|
||||
| WhichModel::MoeA27b => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
},
|
||||
@ -254,7 +276,6 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config_file = repo.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
@ -262,7 +283,16 @@ fn main() -> Result<()> {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
let model = match args.model {
|
||||
WhichModel::MoeA27b => {
|
||||
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Moe(ModelMoe::new(&config, vb)?)
|
||||
}
|
||||
_ => {
|
||||
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Base(ModelBase::new(&config, vb)?)
|
||||
}
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
@ -40,6 +40,7 @@ pub mod quantized_rwkv_v6;
|
||||
pub mod quantized_stable_lm;
|
||||
pub mod quantized_t5;
|
||||
pub mod qwen2;
|
||||
pub mod qwen2_moe;
|
||||
pub mod repvgg;
|
||||
pub mod resnet;
|
||||
pub mod rwkv_v5;
|
||||
|
488
candle-transformers/src/models/qwen2_moe.rs
Normal file
488
candle-transformers/src/models/qwen2_moe.rs
Normal file
@ -0,0 +1,488 @@
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub sliding_window: usize,
|
||||
pub max_window_layers: usize,
|
||||
pub tie_word_embeddings: bool,
|
||||
pub rope_theta: f64,
|
||||
pub rms_norm_eps: f64,
|
||||
pub use_sliding_window: bool,
|
||||
pub hidden_act: Activation,
|
||||
pub decoder_sparse_step: usize,
|
||||
pub moe_intermediate_size: usize,
|
||||
pub shared_expert_intermediate_size: usize,
|
||||
pub num_experts_per_tok: usize,
|
||||
pub num_experts: usize,
|
||||
pub norm_topk_prob: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||
let last_dim = xs.dim(D::Minus1)?;
|
||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
gate_proj: Linear,
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
act_fn: Activation,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(intermediate_sz: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
||||
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
||||
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
down_proj,
|
||||
act_fn: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||
let rhs = xs.apply(&self.up_proj)?;
|
||||
(lhs * rhs)?.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
hidden_size: usize,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
let head_dim = hidden_sz / num_heads;
|
||||
let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
hidden_size: hidden_sz,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_kv_groups;
|
||||
if n_rep == 1 {
|
||||
Ok(xs)
|
||||
} else {
|
||||
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
|
||||
xs.unsqueeze(2)?
|
||||
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, q_len, _) = xs.dims3()?;
|
||||
|
||||
let query_states = self.q_proj.forward(xs)?;
|
||||
let key_states = self.k_proj.forward(xs)?;
|
||||
let value_states = self.v_proj.forward(xs)?;
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let key_states = key_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let value_states = value_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(key_states, value_states)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
let key_states = self.repeat_kv(key_states)?.contiguous()?;
|
||||
let value_states = self.repeat_kv(value_states)?.contiguous()?;
|
||||
|
||||
let attn_output = {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
attn_weights.matmul(&value_states)?
|
||||
};
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, self.hidden_size))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/536ea2aca234fb48c5c69769431d643b0d93b233/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L800
|
||||
#[derive(Debug, Clone)]
|
||||
struct SparseMoeBlock {
|
||||
gate: Linear,
|
||||
experts: Vec<MLP>,
|
||||
shared_expert: MLP,
|
||||
shared_expert_gate: Linear,
|
||||
norm_topk_prob: bool,
|
||||
num_experts_per_tok: usize,
|
||||
}
|
||||
|
||||
impl SparseMoeBlock {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?;
|
||||
let mut experts = Vec::with_capacity(cfg.num_experts);
|
||||
let vb_e = vb.pp("experts");
|
||||
for idx in 0..cfg.num_experts {
|
||||
let expert = MLP::new(cfg.moe_intermediate_size, cfg, vb_e.pp(idx))?;
|
||||
experts.push(expert)
|
||||
}
|
||||
let shared_expert = MLP::new(
|
||||
cfg.shared_expert_intermediate_size,
|
||||
cfg,
|
||||
vb.pp("shared_expert"),
|
||||
)?;
|
||||
let shared_expert_gate = linear_no_bias(cfg.hidden_size, 1, vb.pp("shared_expert_gate"))?;
|
||||
Ok(Self {
|
||||
gate,
|
||||
experts,
|
||||
shared_expert,
|
||||
shared_expert_gate,
|
||||
norm_topk_prob: cfg.norm_topk_prob,
|
||||
num_experts_per_tok: cfg.num_experts_per_tok,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SparseMoeBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b_size, seq_len, hidden_dim) = xs.dims3()?;
|
||||
let xs = xs.reshape(((), hidden_dim))?;
|
||||
let router_logits = xs.apply(&self.gate)?;
|
||||
let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
|
||||
|
||||
// In order to extract topk, we extract the data from the tensor and manipulate it
|
||||
// directly. Maybe we will want to use some custom ops instead at some point.
|
||||
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
|
||||
|
||||
// routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
// top_x contains the row indexes to evaluate for each expert.
|
||||
let mut top_x = vec![vec![]; self.experts.len()];
|
||||
let mut selected_experts = vec![vec![]; self.experts.len()];
|
||||
for (row_idx, rw) in routing_weights.iter().enumerate() {
|
||||
let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
|
||||
dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
|
||||
let mut sum_routing_weights = 0f32;
|
||||
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
|
||||
let expert_idx = expert_idx as usize;
|
||||
let routing_weight = rw[expert_idx];
|
||||
sum_routing_weights += routing_weight;
|
||||
top_x[expert_idx].push(row_idx as u32);
|
||||
}
|
||||
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
|
||||
let expert_idx = expert_idx as usize;
|
||||
let routing_weight = if self.norm_topk_prob {
|
||||
rw[expert_idx] / sum_routing_weights
|
||||
} else {
|
||||
rw[expert_idx]
|
||||
};
|
||||
selected_experts[expert_idx].push(routing_weight)
|
||||
}
|
||||
}
|
||||
|
||||
let mut ys = xs.zeros_like()?;
|
||||
for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
|
||||
let top_x = &top_x[expert_idx];
|
||||
if top_x.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
|
||||
let selected_experts =
|
||||
Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())?
|
||||
.reshape(((), 1))?
|
||||
.to_dtype(xs.dtype())?;
|
||||
// Index the correct hidden states and compute the expert hidden state for
|
||||
// the current expert. We need to make sure to multiply the output hidden
|
||||
// states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
||||
let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
|
||||
// current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
|
||||
let current_hidden_states = expert_layer.forward(¤t_state)?;
|
||||
let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?;
|
||||
ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?;
|
||||
}
|
||||
let shared_expert_output = xs.apply(&self.shared_expert)?;
|
||||
let shared_expert_output = shared_expert_output.broadcast_mul(&candle_nn::ops::sigmoid(
|
||||
&xs.apply(&self.shared_expert_gate)?,
|
||||
)?)?;
|
||||
let ys = (ys + shared_expert_output)?;
|
||||
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum MlpOrMoeBlock {
|
||||
Mlp(MLP),
|
||||
MoeBlock(SparseMoeBlock),
|
||||
}
|
||||
|
||||
impl Module for MlpOrMoeBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::MoeBlock(m) => m.forward(xs),
|
||||
Self::Mlp(m) => m.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
mlp: MlpOrMoeBlock,
|
||||
input_layernorm: RmsNorm,
|
||||
post_attention_layernorm: RmsNorm,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(
|
||||
layer_idx: usize,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||
let mlp = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0 {
|
||||
MlpOrMoeBlock::MoeBlock(SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
|
||||
} else {
|
||||
MlpOrMoeBlock::Mlp(MLP::new(cfg.intermediate_size, cfg, vb.pp("mlp"))?)
|
||||
};
|
||||
let input_layernorm =
|
||||
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
post_attention_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.input_layernorm.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: candle_nn::Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
lm_head: Linear,
|
||||
sliding_window: usize,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(layer_idx, rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head,
|
||||
sliding_window: cfg.sliding_window,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
// Sliding window mask?
|
||||
let mask: Vec<_> = (0..tgt_len)
|
||||
.flat_map(|i| {
|
||||
(0..tgt_len).map(move |j| {
|
||||
if i < j || j + self.sliding_window < i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user