Mirror of https://github.com/huggingface/candle.git
Support sinusoidal embeddings in trocr. (#1690)
* Support sinusoidal embeddings in trocr.
* Support tie-word-embeddings.
@@ -1,9 +1,12 @@
 use crate::models::vit::{Config, Embeddings, Encoder};
-use candle::{Result, Tensor};
+use candle::{DType, Result, Tensor};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
 };
 
+fn default_tie_word_embeddings() -> bool {
+    true
+}
 fn default_use_learned_position_embeddings() -> bool {
     true
 }
@@ -32,6 +35,8 @@ pub struct TrOCRConfig {
     pub decoder_vocab_size: Option<usize>,
     #[serde(default = "default_use_learned_position_embeddings")]
     pub use_learned_position_embeddings: bool,
+    #[serde(default = "default_tie_word_embeddings")]
+    pub tie_word_embeddings: bool,
 }
 
 impl Default for TrOCRConfig {
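Both `serde` attributes make the new fields optional in a model's `config.json`: checkpoints published before these options existed deserialize with `tie_word_embeddings = true` and `use_learned_position_embeddings = true`, preserving the previous behavior. A minimal standalone sketch of that default mechanism (illustrative only, not part of the diff; assumes `serde` with the derive feature and `serde_json`):

    use serde::Deserialize;

    fn default_tie_word_embeddings() -> bool {
        true
    }

    #[derive(Deserialize)]
    struct Cfg {
        #[serde(default = "default_tie_word_embeddings")]
        tie_word_embeddings: bool,
    }

    fn main() {
        // A config that omits the field falls back to the helper's value.
        let cfg: Cfg = serde_json::from_str("{}").unwrap();
        assert!(cfg.tie_word_embeddings);
    }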
@@ -58,6 +63,7 @@ impl Default for TrOCRConfig {
             eos_token_id: 2,
             decoder_vocab_size: Some(50265),
             use_learned_position_embeddings: true,
+            tie_word_embeddings: true,
         }
     }
 }
@@ -78,17 +84,49 @@ impl TrOCRLearnedPositionalEmbedding {
         Ok(Self { offset, weights })
     }
 
+    fn new_sinusoidal(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
+        // https://github.com/huggingface/transformers/blob/58e3d23e97078f361a533b9ec4a6a2de674ea52a/src/transformers/models/trocr/modeling_trocr.py#L81
+        let embedding_dim = cfg.d_model;
+        let half_dim = embedding_dim / 2;
+        let num_positions = cfg.max_position_embeddings + cfg.pad_token_id + 1;
+        let dev = vb.device();
+        let inv_freq: Vec<_> = (0..half_dim)
+            .map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+        let t = Tensor::arange(0u32, num_positions as u32, dev)?
+            .to_dtype(DType::F32)?
+            .reshape((num_positions, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
+        let emb = Tensor::cat(
+            &[
+                emb.narrow(0, 0, cfg.pad_token_id)?,
+                Tensor::zeros((1, embedding_dim), DType::F32, dev)?,
+                emb.narrow(0, cfg.pad_token_id + 1, cfg.max_position_embeddings)?,
+            ],
+            0,
+        )?
+        .contiguous()?;
+        let emb = Embedding::new(emb, embedding_dim);
+        Ok(Self {
+            offset: cfg.pad_token_id + 1,
+            weights: emb,
+        })
+    }
+
     fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
         let (b_sz, seq_len) = input_ids.dims2()?;
 
-        let mut positions = Tensor::arange(
+        let positions = Tensor::arange(
             past_key_values_length,
             seq_len as u32 + past_key_values_length,
             input_ids.device(),
         )?
         .expand((b_sz, seq_len))?;
 
-        positions =
+        let positions =
             positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
         self.weights.forward(&positions)
     }
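For reference, `new_sinusoidal` reproduces the fixed table from the transformers code linked above: for position p and channel i < half_dim the entry is sin(p * inv_freq[i]), the channels from half_dim onward hold the matching cos values, and the row at `pad_token_id` is zeroed out. `forward` then shifts every position by `offset = pad_token_id + 1` before the lookup, which is why the table is allocated with `max_position_embeddings + pad_token_id + 1` rows. A dependency-free sketch of the same table (illustrative only, not part of the diff):

    /// Build the sinusoidal table on plain `Vec`s: sin in the first half of
    /// each row, cos in the second half, and an all-zero row for the pad token.
    fn sinusoidal_table(num_positions: usize, dim: usize, pad: usize) -> Vec<Vec<f32>> {
        let half = dim / 2;
        let mut table = vec![vec![0f32; dim]; num_positions];
        for (p, row) in table.iter_mut().enumerate() {
            if p == pad {
                continue; // the padding position keeps an all-zero embedding
            }
            for i in 0..half {
                let inv_freq = 1f32 / 10000f32.powf(i as f32 / (half - 1) as f32);
                row[i] = (p as f32 * inv_freq).sin();
                row[half + i] = (p as f32 * inv_freq).cos();
            }
        }
        table
    }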
@@ -229,11 +267,9 @@ impl TrOCRDecoderLayer {
         let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
         let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
         let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
-        let activation_fn = candle_nn::Activation::Gelu;
-
         Ok(Self {
             self_attn,
-            activation_fn,
+            activation_fn: cfg.activation_function,
             self_attn_layer_norm,
             encoder_attn,
             encoder_attn_layer_norm,
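This hunk also stops hardcoding the decoder activation: instead of always using GELU, the layer now takes `cfg.activation_function` from the config, so checkpoints that specify a different activation are honored.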
@@ -294,10 +330,11 @@ impl TrOCRDecoder {
         let vb = vb.pp("decoder.model.decoder");
 
         let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
-        if !cfg.use_learned_position_embeddings {
-            candle::bail!("only models with use_learned_position_embeddings=true are supported")
-        }
-        let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
+        let embed_positions = if cfg.use_learned_position_embeddings {
+            TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?
+        } else {
+            TrOCRLearnedPositionalEmbedding::new_sinusoidal(vb.pp("embed_positions"), cfg)?
+        };
         let mut layers = Vec::with_capacity(cfg.decoder_layers);
         let vb_l = vb.pp("layers");
         for idx in 0..cfg.decoder_layers {
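Loading now dispatches on the flag instead of bailing out: a config with `"use_learned_position_embeddings": false` takes the sinusoidal path. Note that `new_sinusoidal` reads nothing from the checkpoint; it uses the `VarBuilder` only to pick the target device, since the table is computed rather than stored.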
@@ -386,8 +423,15 @@ pub struct TrOCRForCausalLM {
 impl TrOCRForCausalLM {
     pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
         let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
-        let output_projection =
-            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
+        let output_projection = if decoder_cfg.tie_word_embeddings {
+            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None)
+        } else {
+            candle_nn::linear_no_bias(
+                decoder_cfg.d_model,
+                decoder_cfg.vocab_size,
+                vb.pp("decoder.output_projection"),
+            )?
+        };
         Ok(Self {
             decoder,
             output_projection,