Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Add the Mixtral model. (#1437)
* Add the Mixtral model.
* Add more of the mixtral layers.
* Add the final layers for mixtral.
* Sketch the expert selection.
* Add some expert routing logic.
* Hopefully finish the routing logic for mixtral.
* Add the mixtral example.
* Fix the weight filenames.
* Bugfix.
* Another fix.
* Yet another fix + remove the unused pragma.
* Shape fix.
* Add a readme.
This commit is contained in:
499
candle-transformers/src/models/mixtral.rs
Normal file
@@ -0,0 +1,499 @@
use crate::models::with_tracing::{linear_no_bias, Linear};
/// Mixtral Model
/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
/// https://mistral.ai/news/mixtral-of-experts/
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;

/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) sliding_window: usize,
    pub(crate) num_experts_per_tok: usize,
    pub(crate) num_local_experts: usize,
    pub(crate) use_flash_attn: bool,
}

impl Config {
    /// https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
    pub fn v0_1_8x7b(use_flash_attn: bool) -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 14336,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 8,
            hidden_act: Activation::Silu,
            max_position_embeddings: 32768,
            rms_norm_eps: 1e-5,
            rope_theta: 1e6,
            sliding_window: 4096,
            num_experts_per_tok: 2,
            num_local_experts: 8,
            use_flash_attn,
        }
    }
}
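
// `Config` derives `Deserialize`, so it can also be populated from JSON instead of the
// hard-coded preset above; note that `use_flash_attn` is not part of the upstream
// config.json and would have to be supplied as well. A minimal sketch, assuming the
// `serde_json` crate is available:
//     let cfg: Config = serde_json::from_str(&json_string)?;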

#[derive(Debug, Clone)]
struct RmsNorm {
    inner: candle_nn::RmsNorm,
    span: tracing::Span,
}

impl RmsNorm {
    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let inner = candle_nn::rms_norm(size, eps, vb)?;
        Ok(Self { inner, span })
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
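
// `rotate_half` maps the two halves of the last dimension [x1, x2] to [-x2, x1]. Using the
// precomputed tables below, `apply_rotary_emb_qkv` then applies the Hugging Face style
// rotary position embedding:
//     q_embed = q * cos + rotate_half(q) * sin
//     k_embed = k * cos + rotate_half(k) * sin
// where the cos/sin rows are indexed starting at `seqlen_offset`, so that incremental
// decoding with a kv cache keeps absolute positions consistent.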

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
        Ok((q_embed, k_embed))
    }
}

#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}
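
// Only the first definition is compiled when the `flash-attn` feature is enabled; otherwise
// the stub above panics if reached, which can only happen when `Config::use_flash_attn` is
// set without the feature. The fallback path in `Attention::forward` computes standard
// scaled dot-product attention instead.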

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
    use_flash_attn: bool,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            kv_cache: None,
            use_flash_attn: cfg.use_flash_attn,
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }
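
    // Grouped-query attention: with the 8x7B config there are 32 query heads but only 8
    // key/value heads, so `repeat_kv` repeats each kv head num_kv_groups = 32 / 8 = 4 times
    // to match the query layout. In `forward` below, new keys/values are concatenated onto
    // `kv_cache` along the sequence dimension, and `seqlen_offset` tells the rotary embedding
    // how many positions are already cached during incremental decoding.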

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;

        let attn_output = if self.use_flash_attn {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = query_states.transpose(1, 2)?;
            let k = key_states.transpose(1, 2)?;
            let v = value_states.transpose(1, 2)?;
            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
            flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
        } else {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }
}

#[derive(Debug, Clone)]
struct BlockSparseTop2MLP {
    w1: Linear,
    w2: Linear,
    w3: Linear,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?;
        let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?;
        let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for BlockSparseTop2MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.w3)?;
        (lhs * rhs)?.apply(&self.w2)
    }
}
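
// Each expert is a gated feed-forward, w2(act_fn(w1(x)) * w3(x)) (SiLU for the 8x7B preset),
// with the same shape as the Mistral MLP but separate weights per expert. The
// `SparseMoeBlock` below owns the router (`gate`) plus `num_local_experts` of these experts
// and dispatches each token to its top `num_experts_per_tok` experts.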

#[derive(Debug, Clone)]
struct SparseMoeBlock {
    gate: Linear,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }
}
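
// Routing overview for the forward pass below: the router produces one logit per expert for
// every token, a softmax turns these into routing weights, and only the top
// `num_experts_per_tok` experts are kept per token. For the 8x7B preset (8 experts, top-2),
// a token whose two largest softmax weights are e.g. 0.30 and 0.20 is sent to those two
// experts with renormalized weights 0.6 and 0.4, and its output hidden state is the weighted
// sum of the two expert outputs.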

impl Module for SparseMoeBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;
        let router_logits = xs.apply(&self.gate)?;
        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // In order to extract topk, we extract the data from the tensor and manipulate it
        // directly. Maybe we will want to use some custom ops instead at some point.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        // top_x contains the row indexes to evaluate for each expert.
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Index the correct hidden states and compute the expert hidden state for
            // the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
            let current_hidden_states = expert_layer.forward(&current_state)?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?;
        let input_layernorm =
            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs
            .apply(&self.post_attention_layernorm)?
            .apply(&self.block_sparse_moe)?;
        residual + xs
    }
}
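
// Each decoder layer uses the usual pre-norm residual structure:
//     x = x + self_attn(input_layernorm(x))
//     x = x + block_sparse_moe(post_attention_layernorm(x))
// i.e. the sparse MoE block takes the place of the dense MLP in a Mistral/Llama layer.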

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    sliding_window: usize,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }
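
    // The mask built below lets position i attend to position j only when j <= i and
    // i - j <= sliding_window; every other entry is set to -inf before the softmax. When
    // `seqlen_offset > 0` (incremental decoding with a kv cache), the already-cached
    // positions on the left are left unmasked by prepending a block of zeros.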

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        // Sliding window mask?
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + self.sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }
}

candle-transformers/src/models/mod.rs
@@ -14,6 +14,7 @@ pub mod llama2_c_weights;
pub mod marian;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;
pub mod mpt;
pub mod persimmon;
pub mod quantized_blip;
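
The commit message also mentions a mixtral example. As a rough, minimal sketch of how the new `Model` can be driven (following the crate naming used inside the candle workspace, and assuming the safetensors weight shards are already on disk and the prompt is already encoded into token ids; a real example would additionally handle hub download, tokenization and generation options):

    use candle::{DType, Device, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::generation::LogitsProcessor;
    use candle_transformers::models::mixtral::{Config, Model};

    fn run(weight_files: &[std::path::PathBuf], prompt_tokens: &[u32]) -> candle::Result<()> {
        let device = Device::Cpu;
        let config = Config::v0_1_8x7b(/* use_flash_attn= */ false);
        // Memory-map the safetensors shards and build the model from them.
        // f32 is used here for simplicity; large checkpoints are usually loaded in f16/bf16 on GPU.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(weight_files, DType::F32, &device)? };
        let mut model = Model::new(&config, vb)?;

        // Prefill on the whole prompt, then sample one token from the last-position logits.
        let input = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?; // (1, 1, vocab_size): only the last position is kept
        let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
        let mut logits_processor = LogitsProcessor::new(299792458, None, None);
        let next_token = logits_processor.sample(&logits)?;
        println!("next token id: {next_token}");
        Ok(())
    }

Subsequent decoding steps would feed one new token at a time with `seqlen_offset` set to the number of tokens already processed, reusing the kv caches held inside each `Attention` layer.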