mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Support sinusoidal embeddings in trocr. (#1690)
* Support sinusoidal embeddings in trocr. * Support tie-word-embeddings.
This commit is contained in:
@ -1,9 +1,12 @@
|
|||||||
use crate::models::vit::{Config, Embeddings, Encoder};
|
use crate::models::vit::{Config, Embeddings, Encoder};
|
||||||
use candle::{Result, Tensor};
|
use candle::{DType, Result, Tensor};
|
||||||
use candle_nn::{
|
use candle_nn::{
|
||||||
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
|
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
fn default_tie_word_embeddings() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
fn default_use_learned_position_embeddings() -> bool {
|
fn default_use_learned_position_embeddings() -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
@ -32,6 +35,8 @@ pub struct TrOCRConfig {
|
|||||||
pub decoder_vocab_size: Option<usize>,
|
pub decoder_vocab_size: Option<usize>,
|
||||||
#[serde(default = "default_use_learned_position_embeddings")]
|
#[serde(default = "default_use_learned_position_embeddings")]
|
||||||
pub use_learned_position_embeddings: bool,
|
pub use_learned_position_embeddings: bool,
|
||||||
|
#[serde(default = "default_tie_word_embeddings")]
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for TrOCRConfig {
|
impl Default for TrOCRConfig {
|
||||||
@ -58,6 +63,7 @@ impl Default for TrOCRConfig {
|
|||||||
eos_token_id: 2,
|
eos_token_id: 2,
|
||||||
decoder_vocab_size: Some(50265),
|
decoder_vocab_size: Some(50265),
|
||||||
use_learned_position_embeddings: true,
|
use_learned_position_embeddings: true,
|
||||||
|
tie_word_embeddings: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -78,17 +84,49 @@ impl TrOCRLearnedPositionalEmbedding {
|
|||||||
Ok(Self { offset, weights })
|
Ok(Self { offset, weights })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn new_sinusoidal(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
|
||||||
|
// https://github.com/huggingface/transformers/blob/58e3d23e97078f361a533b9ec4a6a2de674ea52a/src/transformers/models/trocr/modeling_trocr.py#L81
|
||||||
|
let embedding_dim = cfg.d_model;
|
||||||
|
let half_dim = embedding_dim / 2;
|
||||||
|
let num_positions = cfg.max_position_embeddings + cfg.pad_token_id + 1;
|
||||||
|
let dev = vb.device();
|
||||||
|
let inv_freq: Vec<_> = (0..half_dim)
|
||||||
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||||
|
let t = Tensor::arange(0u32, num_positions as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((num_positions, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
|
||||||
|
let emb = Tensor::cat(
|
||||||
|
&[
|
||||||
|
emb.narrow(0, 0, cfg.pad_token_id)?,
|
||||||
|
Tensor::zeros((1, embedding_dim), DType::F32, dev)?,
|
||||||
|
emb.narrow(0, cfg.pad_token_id + 1, cfg.max_position_embeddings)?,
|
||||||
|
],
|
||||||
|
0,
|
||||||
|
)?
|
||||||
|
.contiguous()?;
|
||||||
|
let emb = Embedding::new(emb, embedding_dim);
|
||||||
|
Ok(Self {
|
||||||
|
offset: cfg.pad_token_id + 1,
|
||||||
|
weights: emb,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
|
fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
|
||||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||||
|
|
||||||
let mut positions = Tensor::arange(
|
let positions = Tensor::arange(
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
seq_len as u32 + past_key_values_length,
|
seq_len as u32 + past_key_values_length,
|
||||||
input_ids.device(),
|
input_ids.device(),
|
||||||
)?
|
)?
|
||||||
.expand((b_sz, seq_len))?;
|
.expand((b_sz, seq_len))?;
|
||||||
|
|
||||||
positions =
|
let positions =
|
||||||
positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
|
positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
|
||||||
self.weights.forward(&positions)
|
self.weights.forward(&positions)
|
||||||
}
|
}
|
||||||
@ -229,11 +267,9 @@ impl TrOCRDecoderLayer {
|
|||||||
let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
|
let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
|
||||||
let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
|
let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
|
||||||
let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
|
let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
|
||||||
let activation_fn = candle_nn::Activation::Gelu;
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
self_attn,
|
self_attn,
|
||||||
activation_fn,
|
activation_fn: cfg.activation_function,
|
||||||
self_attn_layer_norm,
|
self_attn_layer_norm,
|
||||||
encoder_attn,
|
encoder_attn,
|
||||||
encoder_attn_layer_norm,
|
encoder_attn_layer_norm,
|
||||||
@ -294,10 +330,11 @@ impl TrOCRDecoder {
|
|||||||
let vb = vb.pp("decoder.model.decoder");
|
let vb = vb.pp("decoder.model.decoder");
|
||||||
|
|
||||||
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
|
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
|
||||||
if !cfg.use_learned_position_embeddings {
|
let embed_positions = if cfg.use_learned_position_embeddings {
|
||||||
candle::bail!("only models with use_learned_position_embeddings=true are supported")
|
TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?
|
||||||
}
|
} else {
|
||||||
let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
|
TrOCRLearnedPositionalEmbedding::new_sinusoidal(vb.pp("embed_positions"), cfg)?
|
||||||
|
};
|
||||||
let mut layers = Vec::with_capacity(cfg.decoder_layers);
|
let mut layers = Vec::with_capacity(cfg.decoder_layers);
|
||||||
let vb_l = vb.pp("layers");
|
let vb_l = vb.pp("layers");
|
||||||
for idx in 0..cfg.decoder_layers {
|
for idx in 0..cfg.decoder_layers {
|
||||||
@ -386,8 +423,15 @@ pub struct TrOCRForCausalLM {
|
|||||||
impl TrOCRForCausalLM {
|
impl TrOCRForCausalLM {
|
||||||
pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
|
pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
|
let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
|
||||||
let output_projection =
|
let output_projection = if decoder_cfg.tie_word_embeddings {
|
||||||
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
|
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None)
|
||||||
|
} else {
|
||||||
|
candle_nn::linear_no_bias(
|
||||||
|
decoder_cfg.d_model,
|
||||||
|
decoder_cfg.vocab_size,
|
||||||
|
vb.pp("decoder.output_projection"),
|
||||||
|
)?
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
decoder,
|
decoder,
|
||||||
output_projection,
|
output_projection,
|
||||||
|
Reference in New Issue
Block a user