From bf20cc854c4691c04c045dde7e4a25e63eca3d0e Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 10 Feb 2024 15:17:51 +0100
Subject: [PATCH] Support sinusoidal embeddings in trocr. (#1690)

* Support sinusoidal embeddings in trocr.

* Support tie-word-embeddings.
---
 candle-transformers/src/models/trocr.rs | 68 ++++++++++++++++++++-----
 1 file changed, 56 insertions(+), 12 deletions(-)

diff --git a/candle-transformers/src/models/trocr.rs b/candle-transformers/src/models/trocr.rs
index 13cdaa9c..d17eda17 100644
--- a/candle-transformers/src/models/trocr.rs
+++ b/candle-transformers/src/models/trocr.rs
@@ -1,9 +1,12 @@
 use crate::models::vit::{Config, Embeddings, Encoder};
-use candle::{Result, Tensor};
+use candle::{DType, Result, Tensor};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
 };
 
+fn default_tie_word_embeddings() -> bool {
+    true
+}
 fn default_use_learned_position_embeddings() -> bool {
     true
 }
@@ -32,6 +35,8 @@ pub struct TrOCRConfig {
     pub decoder_vocab_size: Option<usize>,
     #[serde(default = "default_use_learned_position_embeddings")]
     pub use_learned_position_embeddings: bool,
+    #[serde(default = "default_tie_word_embeddings")]
+    pub tie_word_embeddings: bool,
 }
 
 impl Default for TrOCRConfig {
@@ -58,6 +63,7 @@ impl Default for TrOCRConfig {
             eos_token_id: 2,
             decoder_vocab_size: Some(50265),
             use_learned_position_embeddings: true,
+            tie_word_embeddings: true,
         }
     }
 }
@@ -78,17 +84,49 @@ impl TrOCRLearnedPositionalEmbedding {
         Ok(Self { offset, weights })
     }
 
+    fn new_sinusoidal(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
+        // https://github.com/huggingface/transformers/blob/58e3d23e97078f361a533b9ec4a6a2de674ea52a/src/transformers/models/trocr/modeling_trocr.py#L81
+        let embedding_dim = cfg.d_model;
+        let half_dim = embedding_dim / 2;
+        let num_positions = cfg.max_position_embeddings + cfg.pad_token_id + 1;
+        let dev = vb.device();
+        let inv_freq: Vec<_> = (0..half_dim)
+            .map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+        let t = Tensor::arange(0u32, num_positions as u32, dev)?
+            .to_dtype(DType::F32)?
+            .reshape((num_positions, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
+        let emb = Tensor::cat(
+            &[
+                emb.narrow(0, 0, cfg.pad_token_id)?,
+                Tensor::zeros((1, embedding_dim), DType::F32, dev)?,
+                emb.narrow(0, cfg.pad_token_id + 1, cfg.max_position_embeddings)?,
+            ],
+            0,
+        )?
+        .contiguous()?;
+        let emb = Embedding::new(emb, embedding_dim);
+        Ok(Self {
+            offset: cfg.pad_token_id + 1,
+            weights: emb,
+        })
+    }
+
     fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
         let (b_sz, seq_len) = input_ids.dims2()?;
 
-        let mut positions = Tensor::arange(
+        let positions = Tensor::arange(
            past_key_values_length,
            seq_len as u32 + past_key_values_length,
            input_ids.device(),
        )?
        .expand((b_sz, seq_len))?;
 
-        positions =
+        let positions =
             positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
         self.weights.forward(&positions)
     }
@@ -229,11 +267,9 @@ impl TrOCRDecoderLayer {
         let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
         let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
         let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
-        let activation_fn = candle_nn::Activation::Gelu;
-
         Ok(Self {
             self_attn,
-            activation_fn,
+            activation_fn: cfg.activation_function,
             self_attn_layer_norm,
             encoder_attn,
             encoder_attn_layer_norm,
@@ -294,10 +330,11 @@ impl TrOCRDecoder {
         let vb = vb.pp("decoder.model.decoder");
 
         let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
-        if !cfg.use_learned_position_embeddings {
-            candle::bail!("only models with use_learned_position_embeddings=true are supported")
-        }
-        let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
+        let embed_positions = if cfg.use_learned_position_embeddings {
+            TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?
+        } else {
+            TrOCRLearnedPositionalEmbedding::new_sinusoidal(vb.pp("embed_positions"), cfg)?
+        };
         let mut layers = Vec::with_capacity(cfg.decoder_layers);
         let vb_l = vb.pp("layers");
         for idx in 0..cfg.decoder_layers {
@@ -386,8 +423,15 @@ pub struct TrOCRForCausalLM {
 impl TrOCRForCausalLM {
     pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
         let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
-        let output_projection =
-            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
+        let output_projection = if decoder_cfg.tie_word_embeddings {
+            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None)
+        } else {
+            candle_nn::linear_no_bias(
+                decoder_cfg.d_model,
+                decoder_cfg.vocab_size,
+                vb.pp("decoder.output_projection"),
+            )?
+        };
         Ok(Self {
             decoder,
             output_projection,
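
Editor's note (not part of the patch): the sinusoidal table added above follows the transformers reference linked in the diff comment: positions are looked up with an offset of pad_token_id + 1 and the row at pad_token_id is zeroed. Below is a minimal standalone sketch of the same construction, useful only as a shape sanity-check. It reuses the tensor ops already present in the diff (`candle` here is candle-core as imported in the file); the function name `sinusoidal_table` and the example dimensions (d_model = 8, max_position_embeddings = 16, pad_token_id = 1) are made up for illustration and are not TrOCR defaults.

    use candle::{DType, Device, Result, Tensor};

    // Build a (max_positions + pad_token_id + 1, d_model) sinusoidal table,
    // [sin | cos] per row, with the pad_token_id row zeroed out.
    fn sinusoidal_table(d_model: usize, max_positions: usize, pad_token_id: usize) -> Result<Tensor> {
        let dev = Device::Cpu;
        let half_dim = d_model / 2;
        let num_positions = max_positions + pad_token_id + 1;
        // One inverse frequency per sin/cos pair.
        let inv_freq: Vec<f32> = (0..half_dim)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
            .collect();
        let inv_freq = Tensor::from_vec(inv_freq, (1, half_dim), &dev)?;
        let t = Tensor::arange(0u32, num_positions as u32, &dev)?
            .to_dtype(DType::F32)?
            .reshape((num_positions, 1))?;
        // Outer product position x inverse frequency.
        let freqs = t.matmul(&inv_freq)?;
        let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
        // Splice a zero row in at pad_token_id, as the patch does.
        Tensor::cat(
            &[
                emb.narrow(0, 0, pad_token_id)?,
                Tensor::zeros((1, d_model), DType::F32, &dev)?,
                emb.narrow(0, pad_token_id + 1, max_positions)?,
            ],
            0,
        )?
        .contiguous()
    }

    fn main() -> Result<()> {
        // Hypothetical example dimensions, not TrOCR defaults.
        let table = sinusoidal_table(8, 16, 1)?;
        let (rows, cols) = table.dims2()?;
        assert_eq!((rows, cols), (16 + 1 + 1, 8));
        Ok(())
    }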