mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00
Add Qwen3 MoE (#2934)
* qwen-moe rebase * lint * fixed rebase error * swapped normal MoE model with CausalMoE Model in example, and swapped the tie word embeddings if statement * updated readme
This commit is contained in:
@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed
|
|||||||
print(i)
|
print(i)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The qwen3 MoE variant is also an option.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. <think></think>." --model "3-moe-a3b"
|
||||||
|
> In morning's hush, where daisies sleep,
|
||||||
|
> A fleeting dance through sunlit deep—
|
||||||
|
> They flutter soft on gossamer thread,
|
||||||
|
> The messengers of spring’s own head.
|
||||||
|
>
|
||||||
|
> With painted sails and delicate grace,
|
||||||
|
> They drift from bloom to blossom's face.
|
||||||
|
> Each wing a tale in hues unseen,
|
||||||
|
> Of ancient dreams and secrets between.
|
||||||
|
>
|
||||||
|
> No sound they make, yet still they speak—
|
||||||
|
> Of time that flies, of life so brief.
|
||||||
|
> A fleeting kiss on summer’s breath,
|
||||||
|
> A whisper lost before death.
|
||||||
|
>
|
||||||
|
> Yet in their flight, the soul takes wing,
|
||||||
|
> And for a moment, all is spring.
|
||||||
|
> For though they fade, they never die—
|
||||||
|
> Their beauty lives where hearts can fly.
|
||||||
|
> 161 tokens generated (3.00 token/s)
|
||||||
|
```
|
||||||
|
@ -10,6 +10,7 @@ use clap::Parser;
|
|||||||
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
|
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
|
||||||
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
||||||
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
|
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
|
||||||
|
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
@ -22,6 +23,7 @@ enum Model {
|
|||||||
Base(ModelBase),
|
Base(ModelBase),
|
||||||
Moe(ModelMoe),
|
Moe(ModelMoe),
|
||||||
Base3(Model3),
|
Base3(Model3),
|
||||||
|
Moe3(ModelMoe3),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
@ -30,6 +32,7 @@ impl Model {
|
|||||||
Self::Moe(ref mut m) => m.forward(xs, s),
|
Self::Moe(ref mut m) => m.forward(xs, s),
|
||||||
Self::Base(ref mut m) => m.forward(xs, s),
|
Self::Base(ref mut m) => m.forward(xs, s),
|
||||||
Self::Base3(ref mut m) => m.forward(xs, s),
|
Self::Base3(ref mut m) => m.forward(xs, s),
|
||||||
|
Self::Moe3(ref mut m) => m.forward(xs, s),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -167,6 +170,8 @@ enum WhichModel {
|
|||||||
W3_4b,
|
W3_4b,
|
||||||
#[value(name = "3-8b")]
|
#[value(name = "3-8b")]
|
||||||
W3_8b,
|
W3_8b,
|
||||||
|
#[value(name = "3-moe-a3b")]
|
||||||
|
W3MoeA3b,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -273,6 +278,7 @@ fn main() -> Result<()> {
|
|||||||
WhichModel::W3_1_7b => ("3", "1.7B"),
|
WhichModel::W3_1_7b => ("3", "1.7B"),
|
||||||
WhichModel::W3_4b => ("3", "4B"),
|
WhichModel::W3_4b => ("3", "4B"),
|
||||||
WhichModel::W3_8b => ("3", "8B"),
|
WhichModel::W3_8b => ("3", "8B"),
|
||||||
|
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
|
||||||
};
|
};
|
||||||
format!("Qwen/Qwen{version}-{size}")
|
format!("Qwen/Qwen{version}-{size}")
|
||||||
}
|
}
|
||||||
@ -308,7 +314,8 @@ fn main() -> Result<()> {
|
|||||||
| WhichModel::MoeA27b
|
| WhichModel::MoeA27b
|
||||||
| WhichModel::W3_1_7b
|
| WhichModel::W3_1_7b
|
||||||
| WhichModel::W3_4b
|
| WhichModel::W3_4b
|
||||||
| WhichModel::W3_8b => {
|
| WhichModel::W3_8b
|
||||||
|
| WhichModel::W3MoeA3b => {
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -334,6 +341,10 @@ fn main() -> Result<()> {
|
|||||||
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||||
Model::Base3(Model3::new(&config, vb)?)
|
Model::Base3(Model3::new(&config, vb)?)
|
||||||
}
|
}
|
||||||
|
WhichModel::W3MoeA3b => {
|
||||||
|
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||||
|
Model::Moe3(ModelMoe3::new(&config, vb)?)
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||||
Model::Base(ModelBase::new(&config, vb)?)
|
Model::Base(ModelBase::new(&config, vb)?)
|
||||||
|
@ -100,6 +100,7 @@ pub mod quantized_t5;
|
|||||||
pub mod qwen2;
|
pub mod qwen2;
|
||||||
pub mod qwen2_moe;
|
pub mod qwen2_moe;
|
||||||
pub mod qwen3;
|
pub mod qwen3;
|
||||||
|
pub mod qwen3_moe;
|
||||||
pub mod recurrent_gemma;
|
pub mod recurrent_gemma;
|
||||||
pub mod repvgg;
|
pub mod repvgg;
|
||||||
pub mod resnet;
|
pub mod resnet;
|
||||||
|
355
candle-transformers/src/models/qwen3_moe.rs
Normal file
355
candle-transformers/src/models/qwen3_moe.rs
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
use crate::models::{
|
||||||
|
qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding},
|
||||||
|
with_tracing::{linear_no_bias, Linear, RmsNorm},
|
||||||
|
};
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{Activation, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub head_dim: usize,
|
||||||
|
pub attention_bias: bool,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub sliding_window: Option<usize>,
|
||||||
|
pub max_window_layers: usize,
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
pub rope_theta: f64,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
pub use_sliding_window: bool,
|
||||||
|
pub hidden_act: Activation,
|
||||||
|
// MoE specific configuration
|
||||||
|
pub decoder_sparse_step: usize,
|
||||||
|
pub moe_intermediate_size: usize,
|
||||||
|
pub num_experts_per_tok: usize,
|
||||||
|
pub num_experts: usize,
|
||||||
|
pub norm_topk_prob: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&Config> for Qwen3Config {
|
||||||
|
fn from(val: &Config) -> Self {
|
||||||
|
Qwen3Config {
|
||||||
|
vocab_size: val.vocab_size,
|
||||||
|
hidden_size: val.hidden_size,
|
||||||
|
intermediate_size: val.intermediate_size,
|
||||||
|
num_hidden_layers: val.num_hidden_layers,
|
||||||
|
num_attention_heads: val.num_attention_heads,
|
||||||
|
head_dim: val.head_dim,
|
||||||
|
attention_bias: val.attention_bias,
|
||||||
|
num_key_value_heads: val.num_key_value_heads,
|
||||||
|
max_position_embeddings: val.max_position_embeddings,
|
||||||
|
sliding_window: val.sliding_window,
|
||||||
|
max_window_layers: val.max_window_layers,
|
||||||
|
tie_word_embeddings: val.tie_word_embeddings,
|
||||||
|
rope_theta: val.rope_theta,
|
||||||
|
rms_norm_eps: val.rms_norm_eps,
|
||||||
|
use_sliding_window: val.use_sliding_window,
|
||||||
|
hidden_act: val.hidden_act,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Qwen3MLPExpert {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3MLPExpert {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: linear_no_bias(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.moe_intermediate_size,
|
||||||
|
vb.pp("gate_proj"),
|
||||||
|
)?,
|
||||||
|
up_proj: linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?,
|
||||||
|
down_proj: linear_no_bias(
|
||||||
|
cfg.moe_intermediate_size,
|
||||||
|
cfg.hidden_size,
|
||||||
|
vb.pp("down_proj"),
|
||||||
|
)?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3MLPExpert {
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = x.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen3 Sparse MoE Block implementation
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Qwen3SparseMoeBlock {
|
||||||
|
gate: Linear,
|
||||||
|
experts: Vec<Qwen3MLPExpert>,
|
||||||
|
norm_topk_prob: bool,
|
||||||
|
num_experts_per_tok: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3SparseMoeBlock {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?;
|
||||||
|
let mut experts = Vec::with_capacity(cfg.num_experts);
|
||||||
|
let vb_e = vb.pp("experts");
|
||||||
|
for idx in 0..cfg.num_experts {
|
||||||
|
let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?;
|
||||||
|
experts.push(expert)
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
gate,
|
||||||
|
experts,
|
||||||
|
norm_topk_prob: cfg.norm_topk_prob,
|
||||||
|
num_experts_per_tok: cfg.num_experts_per_tok,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3SparseMoeBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let (b_size, seq_len, hidden_dim) = xs.dims3()?;
|
||||||
|
let xs = xs.reshape(((), hidden_dim))?;
|
||||||
|
let router_logits = xs.apply(&self.gate)?;
|
||||||
|
let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
|
||||||
|
|
||||||
|
// Extract topk experts per token
|
||||||
|
let experts_per_tok = routing_weights
|
||||||
|
.arg_sort_last_dim(false)?
|
||||||
|
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
|
||||||
|
.contiguous()?;
|
||||||
|
let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
|
||||||
|
|
||||||
|
// Extract needed data
|
||||||
|
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
|
||||||
|
let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
|
||||||
|
let mut top_x = vec![vec![]; self.experts.len()];
|
||||||
|
let mut selected_experts = vec![vec![]; self.experts.len()];
|
||||||
|
for (row_idx, (rw, expert_idxs)) in routing_weights
|
||||||
|
.iter()
|
||||||
|
.zip(experts_per_tok.iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
let sum_rw = rw.iter().sum::<f32>();
|
||||||
|
for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
|
||||||
|
top_x[expert_idx as usize].push(row_idx as u32);
|
||||||
|
let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
|
||||||
|
selected_experts[expert_idx as usize].push(rw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process through experts
|
||||||
|
let mut ys = xs.zeros_like()?;
|
||||||
|
for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
|
||||||
|
let top_x = &top_x[expert_idx];
|
||||||
|
if top_x.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
|
||||||
|
let selected_experts =
|
||||||
|
Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())?
|
||||||
|
.reshape(((), 1))?
|
||||||
|
.to_dtype(xs.dtype())?;
|
||||||
|
|
||||||
|
let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
|
||||||
|
let current_hidden_states = expert_layer.forward(¤t_state)?;
|
||||||
|
let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?;
|
||||||
|
ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
ys.reshape((b_size, seq_len, hidden_dim))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MLP or MoE decision enum
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum Qwen3FeedForward {
|
||||||
|
Mlp(Qwen3MLP),
|
||||||
|
MoE(Qwen3SparseMoeBlock),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3FeedForward {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Mlp(m) => m.forward(xs),
|
||||||
|
Self::MoE(m) => m.forward(xs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DecoderLayer {
|
||||||
|
self_attn: Qwen3Attention,
|
||||||
|
feed_forward: Qwen3FeedForward,
|
||||||
|
ln1: RmsNorm,
|
||||||
|
ln2: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderLayer {
|
||||||
|
fn new(
|
||||||
|
layer_idx: usize,
|
||||||
|
cfg: &Config,
|
||||||
|
rotary: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?;
|
||||||
|
|
||||||
|
// Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step
|
||||||
|
let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0
|
||||||
|
{
|
||||||
|
Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
|
||||||
|
} else {
|
||||||
|
Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?)
|
||||||
|
};
|
||||||
|
|
||||||
|
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||||
|
let ln2 = RmsNorm::new(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_attention_layernorm"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
feed_forward,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
|
||||||
|
let h = self.ln1.forward(x)?;
|
||||||
|
let h = self.self_attn.forward(&h, mask, offset)?;
|
||||||
|
let x = (x + h)?;
|
||||||
|
let h2 = self.ln2.forward(&x)?;
|
||||||
|
let h2 = h2.apply(&self.feed_forward)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embed_tokens: candle_nn::Embedding,
|
||||||
|
layers: Vec<DecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embed_tokens =
|
||||||
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(
|
||||||
|
vb.dtype(),
|
||||||
|
&cfg.into(),
|
||||||
|
vb.device(),
|
||||||
|
)?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
for l in &mut self.layers {
|
||||||
|
l.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(
|
||||||
|
&self,
|
||||||
|
b: usize,
|
||||||
|
tgt: usize,
|
||||||
|
offset: usize,
|
||||||
|
sw: Option<usize>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| {
|
||||||
|
(0..(tgt + offset)).map(move |j| {
|
||||||
|
let past_ok = j <= i + offset;
|
||||||
|
let sw_ok = match sw {
|
||||||
|
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
|
||||||
|
None => true,
|
||||||
|
};
|
||||||
|
if past_ok && sw_ok {
|
||||||
|
0.
|
||||||
|
} else {
|
||||||
|
minf
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset, None)?)
|
||||||
|
};
|
||||||
|
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ModelForCausalLM {
|
||||||
|
base: Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelForCausalLM {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let base = Model::new(cfg, vb.clone())?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
|
||||||
|
} else {
|
||||||
|
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
|
};
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
self.base
|
||||||
|
.forward(input, offset)?
|
||||||
|
.narrow(1, l - 1, 1)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user