Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 02:16:37 +00:00.
Marian MT model (#1210)
* Skeleton files for the Marian MT model.
* Marian initialization.
* Implement the attention forward method.
* Forward pass for the encoder side.
* Expose the encoder and decoder.
* Start plugging the decoder.
* Forward pass for the decoder layer.
* Set up the Marian example.
* Add some missing backtraces.
* Bugfix.
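Taken together, the commits wire up a standard encoder-decoder translation flow: the source sentence is encoded once, and the decoder is then re-run token by token against those fixed encoder states. In sketch form (a condensed view of the example further down, not additional API):

let encoder_xs = model.encoder().forward(&src_tokens, 0)?;
let mut token_ids = vec![decoder_start_token_id];
loop {
    let input_ids = Tensor::new(token_ids.as_slice(), &device)?.unsqueeze(0)?;
    let logits = model.decode(&input_ids, &encoder_xs)?;
    // sample the next token from the last position's logits, stop on EOS...
}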
@@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         let src = match src_l.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         let dim = self.dim;
         let ids_dims = self.ids_l.dims();
@@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
         let src = match layout.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
-            None => Err(Error::RequiresContiguous { op: "index-select" })?,
+            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
         };
         let dim = self.dim;
         let n_ids = match self.ids_l.dims() {
@@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
+            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
             Some((o1, o2)) => &src[o1..o2],
         };

@@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         for left_i in 0..ids_left_len {
             let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "index-add" })?,
+            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
             Some((o1, o2)) => &src[o1..o2],
         };
         let dim = self.dim;
@@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
             Self::U32(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
             Self::I64(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
         }
     }

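Every hunk above makes the same change: error values that used to be returned bare via `Err(...)?` are now routed through `.bt()`, so a backtrace is captured where the error is constructed rather than where it is eventually printed. A minimal sketch of the mechanism (illustrative only, not candle's actual `Error` type):

use std::backtrace::Backtrace;

#[derive(Debug)]
enum Error {
    RequiresContiguous { op: &'static str },
    // Carries the original error plus the backtrace captured by `bt()`.
    WithBacktrace {
        inner: Box<Error>,
        backtrace: Box<Backtrace>,
    },
}

impl Error {
    // Capture a backtrace at the construction site. `Backtrace::capture()`
    // is cheap (effectively a no-op) unless RUST_BACKTRACE or
    // RUST_LIB_BACKTRACE is set in the environment.
    fn bt(self) -> Self {
        Error::WithBacktrace {
            inner: Box::new(self),
            backtrace: Box::new(Backtrace::capture()),
        }
    }
}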
candle-examples/examples/marian-mt/main.rs (new file, 90 lines)
@@ -0,0 +1,90 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: String,

    #[arg(long)]
    tokenizer: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated.
    #[arg(long)]
    text: String,
}

// Stop at the end-of-sequence token; this must match `eos_token_id` in
// `Config::opus_mt_tc_big_fr_en` below (43311), not the hard-coded 102
// carried over from a different tokenizer.
const SEP_TOKEN_ID: u32 = 43311;

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let config = marian::Config::opus_mt_tc_big_fr_en();

    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
    let model = marian::MTModel::new(&config, vb)?;

    let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?;
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let encoder_xs = {
        let tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

    // Start decoding from the configured decoder start token rather than a
    // hard-coded id from another model's vocabulary.
    let mut token_ids = vec![config.decoder_start_token_id as u32];
    for index in 0..1000 {
        // TODO: Add a kv cache. Without one, the full prefix is re-run on
        // every step, so `context_size` is always the whole sequence here.
        let context_size = if index >= 1000 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }

    Ok(())
}
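A note on the sampling setup: `LogitsProcessor::new(1337, None, None)` is seeded but given no temperature and no top-p, which (as of this commit) makes `sample` fall back to plain argmax, i.e. greedy decoding. A minimal sketch of that selection step over raw logits (illustrative, not the `LogitsProcessor` internals):

// Greedy decoding: pick the token id with the highest logit.
fn argmax(logits: &[f32]) -> u32 {
    let mut best = 0usize;
    for (i, &logit) in logits.iter().enumerate() {
        if logit > logits[best] {
            best = i;
        }
    }
    best as u32
}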
candle-transformers/src/models/marian.rs (new file, 413 lines)
@@ -0,0 +1,413 @@
#![allow(unused)]
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
use candle::{Module, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};

#[derive(Debug, Clone)]
pub struct Config {
    pub vocab_size: usize,
    pub decoder_vocab_size: Option<usize>,
    pub max_position_embeddings: usize,
    pub encoder_layers: usize,
    pub encoder_ffn_dim: usize,
    pub encoder_attention_heads: usize,
    pub decoder_layers: usize,
    pub decoder_ffn_dim: usize,
    pub decoder_attention_heads: usize,
    pub use_cache: bool,
    pub is_encoder_decoder: bool,
    pub activation_function: candle_nn::Activation,
    pub d_model: usize,
    pub decoder_start_token_id: usize,
    pub scale_embedding: bool,
    pub pad_token_id: usize,
    pub eos_token_id: usize,
    pub forced_eos_token_id: usize,
    pub share_encoder_decoder_embeddings: bool,
}

impl Config {
    // https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en/blob/main/config.json
    pub fn opus_mt_tc_big_fr_en() -> Self {
        Self {
            activation_function: candle_nn::Activation::Relu,
            d_model: 1024,
            decoder_attention_heads: 16,
            decoder_ffn_dim: 4096,
            decoder_layers: 6,
            decoder_start_token_id: 53016,
            decoder_vocab_size: Some(53017),
            encoder_attention_heads: 16,
            encoder_ffn_dim: 4096,
            encoder_layers: 6,
            eos_token_id: 43311,
            forced_eos_token_id: 43311,
            is_encoder_decoder: true,
            max_position_embeddings: 1024,
            pad_token_id: 53016,
            scale_embedding: true,
            share_encoder_decoder_embeddings: true,
            use_cache: true,
            vocab_size: 53017,
        }
    }
}

#[derive(Debug, Clone)]
struct SinusoidalPositionalEmbedding {
    emb: Embedding,
}

impl SinusoidalPositionalEmbedding {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dev = vb.device();
        let dtype = vb.dtype();
        let num_positions = cfg.max_position_embeddings;
        let dim = cfg.d_model;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, num_positions as u32, dev)?
            .to_dtype(dtype)?
            .reshape((num_positions, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?;
        let cos = freqs.cos()?;
        let weights = Tensor::cat(&[&sin, &cos], 1)?.contiguous()?;
        let emb = Embedding::from_weights(weights)?;
        Ok(Self { emb })
    }

    fn forward(&self, input_ids: &Tensor, past_kv_len: usize) -> Result<Tensor> {
        let seq_len = input_ids.dim(1)?;
        Tensor::arange(
            past_kv_len as u32,
            (past_kv_len + seq_len) as u32,
            input_ids.device(),
        )?
        .apply(&self.emb)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    scaling: f64,
    num_heads: usize,
    head_dim: usize,
}

impl Attention {
    fn new(cfg: &Config, is_decoder: bool, vb: VarBuilder) -> Result<Self> {
        let num_heads = if is_decoder {
            cfg.decoder_attention_heads
        } else {
            cfg.encoder_attention_heads
        };
        let embed_dim = cfg.d_model;
        let head_dim = embed_dim / num_heads;
        let scaling = (head_dim as f64).powf(-0.5);
        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            scaling,
            num_heads,
            head_dim,
        })
    }

    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
        tensor
            .reshape((bsz, (), self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(&self, xs: &Tensor, kv_states: Option<&Tensor>) -> Result<Tensor> {
        let is_cross_attn = kv_states.is_some();
        let (b_sz, tgt_len, _) = xs.dims3()?;
        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
        let (key_states, value_states) = match kv_states {
            None => {
                let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
                let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
                (key_states, value_states)
            }
            Some(kv_states) => {
                let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
                let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
                (key_states, value_states)
            }
        };
        let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
        let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
        let key_states = key_states.reshape(proj_shape)?;
        let value_states = value_states.reshape(proj_shape)?;
        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
        // todo: attn_mask
        let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_probs.matmul(&value_states)?;
        attn_output
            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
            .apply(&self.out_proj)
    }
}

#[derive(Debug, Clone)]
struct EncoderLayer {
    self_attn: Attention,
    self_attn_layer_norm: LayerNorm,
    activation_fn: candle_nn::Activation,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
}

impl EncoderLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
        let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
        let fc1 = linear(cfg.d_model, cfg.encoder_ffn_dim, vb.pp("fc1"))?;
        let fc2 = linear(cfg.encoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
        let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
        Ok(Self {
            self_attn,
            self_attn_layer_norm,
            activation_fn: cfg.activation_function,
            fc1,
            fc2,
            final_layer_norm,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs =
            (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
        let residual = &xs;
        let xs = xs
            .apply(&self.fc1)?
            .apply(&self.activation_fn)?
            .apply(&self.fc2)?;
        (xs + residual)?.apply(&self.final_layer_norm)
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    self_attn_layer_norm: LayerNorm,
    activation_fn: candle_nn::Activation,
    encoder_attn: Attention,
    encoder_attn_layer_norm: LayerNorm,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
}

impl DecoderLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
        let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
        let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
        // Load the cross-attention layer norm from its own weight path, not
        // from "self_attn_layer_norm".
        let encoder_attn_layer_norm =
            layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
        let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
        let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
        let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
        Ok(Self {
            self_attn,
            self_attn_layer_norm,
            activation_fn: cfg.activation_function,
            encoder_attn,
            encoder_attn_layer_norm,
            fc1,
            fc2,
            final_layer_norm,
        })
    }

    fn forward(&self, xs: &Tensor, encoder_xs: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs =
            (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
        let xs = match encoder_xs {
            None => xs,
            Some(encoder_xs) => {
                let residual = &xs;
                let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
                // Normalize the cross-attention branch with its own layer norm.
                (residual + xs)?.apply(&self.encoder_attn_layer_norm)?
            }
        };
        let residual = &xs;
        let xs = xs
            .apply(&self.fc1)?
            .apply(&self.activation_fn)?
            .apply(&self.fc2)?;
        (xs + residual)?.apply(&self.final_layer_norm)
    }
}

#[derive(Debug, Clone)]
pub struct Encoder {
    embed_tokens: Embedding,
    embed_positions: SinusoidalPositionalEmbedding,
    layers: Vec<EncoderLayer>,
    embed_scale: Option<f64>,
}

impl Encoder {
    fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
        let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
        let mut layers = Vec::with_capacity(cfg.encoder_layers);
        let vb_l = vb.pp("layers");
        for idx in 0..cfg.encoder_layers {
            let layer = EncoderLayer::new(cfg, vb_l.pp(idx))?;
            layers.push(layer)
        }
        let embed_scale = if cfg.scale_embedding {
            Some((cfg.d_model as f64).sqrt())
        } else {
            None
        };
        Ok(Self {
            embed_tokens: embed_tokens.clone(),
            embed_positions,
            layers,
            embed_scale,
        })
    }

    pub fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
        let xs = xs.apply(&self.embed_tokens)?;
        let xs = match self.embed_scale {
            None => xs,
            Some(scale) => (xs * scale)?,
        };
        let embed_pos = self
            .embed_positions
            .forward(&xs, past_kv_len)?
            .unsqueeze(0)?;
        let mut xs = xs.broadcast_add(&embed_pos)?;
        for layer in self.layers.iter() {
            xs = layer.forward(&xs)?
        }
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct Decoder {
    embed_tokens: Embedding,
    embed_positions: SinusoidalPositionalEmbedding,
    layers: Vec<DecoderLayer>,
    embed_scale: Option<f64>,
}

impl Decoder {
    fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
        let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
        let mut layers = Vec::with_capacity(cfg.decoder_layers);
        let vb_l = vb.pp("layers");
        for idx in 0..cfg.decoder_layers {
            let layer = DecoderLayer::new(cfg, vb_l.pp(idx))?;
            layers.push(layer)
        }
        let embed_scale = if cfg.scale_embedding {
            Some((cfg.d_model as f64).sqrt())
        } else {
            None
        };
        Ok(Self {
            embed_tokens: embed_tokens.clone(),
            embed_positions,
            layers,
            embed_scale,
        })
    }

    pub fn forward(
        &self,
        xs: &Tensor,
        encoder_xs: Option<&Tensor>,
        past_kv_len: usize,
    ) -> Result<Tensor> {
        let xs = xs.apply(&self.embed_tokens)?;
        let xs = match self.embed_scale {
            None => xs,
            Some(scale) => (xs * scale)?,
        };
        let embed_pos = self
            .embed_positions
            .forward(&xs, past_kv_len)?
            .unsqueeze(0)?;
        let mut xs = xs.broadcast_add(&embed_pos)?;
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, encoder_xs)?
        }
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct Model {
    shared: Embedding,
    encoder: Encoder,
    decoder: Decoder,
}

impl Model {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
        let encoder = Encoder::new(cfg, &shared, vb.pp("encoder"))?;
        let decoder = Decoder::new(cfg, &shared, vb.pp("decoder"))?;
        Ok(Self {
            shared,
            encoder,
            decoder,
        })
    }
}

#[derive(Debug, Clone)]
pub struct MTModel {
    model: Model,
    final_logits_bias: Tensor,
}

impl MTModel {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
        let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
        let model = Model::new(cfg, vb.pp("model"))?;
        Ok(Self {
            model,
            final_logits_bias,
        })
    }

    pub fn encoder(&self) -> &Encoder {
        &self.model.encoder
    }

    pub fn decoder(&self) -> &Decoder {
        &self.model.decoder
    }

    pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
        self.model.decoder.forward(xs, Some(encoder_xs), 0)
    }
}
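Two bits of math in this file are worth making explicit. SinusoidalPositionalEmbedding::new builds a fixed (max_position_embeddings, d_model) table where, unlike the interleaved layout of the original transformer paper, all sines are concatenated before all cosines along the feature axis:

\[
\mathrm{PE}[t,\,k] = \sin\!\big(t \cdot 10000^{-2k/d}\big), \qquad
\mathrm{PE}[t,\,k + d/2] = \cos\!\big(t \cdot 10000^{-2k/d}\big), \qquad 0 \le k < d/2,
\]

with \(d\) = d_model and \(t\) the absolute position, offset by past_kv_len in forward. In Attention::forward, the usual scaled dot-product attention

\[
\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\big(QK^\top / \sqrt{d_h}\big)\,V
\]

is computed with the \(1/\sqrt{d_h}\) factor (head_dim to the power -0.5, stored as `scaling`) folded into the query projection output rather than applied to the score matrix. Note that the attention mask is still a `// todo` at this point, so the decoder self-attention is not yet causally masked.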
@@ -10,6 +10,7 @@ pub mod jina_bert;
 pub mod llama;
 pub mod llama2_c;
 pub mod llama2_c_weights;
+pub mod marian;
 pub mod mistral;
 pub mod mixformer;
 pub mod mpt;
@@ -14,6 +14,13 @@ impl Embedding {
         Ok(Self { inner, span })
     }

+    pub fn from_weights(weights: Tensor) -> Result<Self> {
+        let (_in_size, out_size) = weights.dims2()?;
+        let inner = candle_nn::Embedding::new(weights, out_size);
+        let span = tracing::span!(tracing::Level::TRACE, "embedding");
+        Ok(Self { inner, span })
+    }
+
     pub fn embeddings(&self) -> &Tensor {
         self.inner.embeddings()
     }
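`from_weights` exists so marian.rs can wrap its precomputed sinusoidal table in the same traced `Embedding` type used for learned tables; the output size is read from the second dimension of the weight tensor. A hedged usage sketch (it assumes the `with_tracing` module is importable from outside the crate, and a `(num_positions, d_model)` weight tensor like the one built in `SinusoidalPositionalEmbedding::new`):

use candle::{Result, Tensor};
use candle_transformers::models::with_tracing::Embedding; // assumed public

// Hypothetical helper: look up the embeddings for positions 0..n from a
// precomputed table, e.g. the [sin | cos] table from marian.rs.
fn first_n_position_embeddings(weights: Tensor, n: u32) -> Result<Tensor> {
    let device = weights.device().clone();
    let emb = Embedding::from_weights(weights)?; // consumes the table
    let positions = Tensor::arange(0u32, n, &device)?;
    positions.apply(&emb) // shape: (n, d_model)
}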