Use the repo config for trocr rather than hardcoding it + small tweaks. (#1689)

* Use the repo config for trocr rather than hardcoding it + small tweaks.

* Add support for the printed models.

* Fail with an appropriate error message on missing position embeddings.
Author: Laurent Mazare
Date: 2024-02-10 13:15:03 +01:00
Committed by: GitHub
Parent: 67589791d2
Commit: 42ce593ec6
3 changed files with 77 additions and 52 deletions
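
The heart of the change: `TrOCRConfig` is now deserialized from the model repo's `config.json` instead of being hardcoded, which is also what makes the printed checkpoints (whose hyperparameters differ from the handwritten defaults) loadable. A minimal sketch of what this enables, assuming a locally downloaded `config.json`; the `load_trocr_config` helper and the `anyhow`/`serde_json` dependencies are illustrative choices, not part of this commit:

use candle_transformers::models::trocr::TrOCRConfig;

// Sketch only: parse the repo's config.json rather than relying on
// TrOCRConfig::default(). Fetching the file (e.g. via hf-hub) is elided.
fn load_trocr_config(path: &std::path::Path) -> anyhow::Result<TrOCRConfig> {
    let raw = std::fs::read_to_string(path)?;
    // Possible because TrOCRConfig now derives serde::Deserialize (see diff below).
    let cfg: TrOCRConfig = serde_json::from_str(&raw)?;
    Ok(cfg)
}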

candle-transformers/src/models/trocr.rs

@@ -3,13 +3,16 @@ use candle::{Result, Tensor};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
 };
-use serde::Deserialize;
 
-#[derive(Debug, Clone, PartialEq, Deserialize)]
+fn default_use_learned_position_embeddings() -> bool {
+    true
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct TrOCRConfig {
     pub vocab_size: usize,
     pub d_model: usize,
-    pub hidden_size: usize,
+    pub cross_attention_hidden_size: usize,
     pub decoder_layers: usize,
     pub decoder_attention_heads: usize,
     pub decoder_ffn_dim: usize,
@@ -23,13 +26,12 @@ pub struct TrOCRConfig {
     pub decoder_layerdrop: f64,
     pub use_cache: bool,
     pub scale_embedding: bool,
-    pub use_learned_position_embeddings: bool,
-    pub layernorm_embedding: bool,
     pub pad_token_id: usize,
     pub bos_token_id: usize,
     pub eos_token_id: u32,
-    pub num_attention_heads: usize,
     pub decoder_vocab_size: Option<usize>,
+    #[serde(default = "default_use_learned_position_embeddings")]
+    pub use_learned_position_embeddings: bool,
 }
 
 impl Default for TrOCRConfig {
@@ -37,7 +39,7 @@ impl Default for TrOCRConfig {
         Self {
             vocab_size: 50265,
             d_model: 1024,
-            hidden_size: 768,
+            cross_attention_hidden_size: 768,
             decoder_layers: 12,
             decoder_attention_heads: 16,
             decoder_ffn_dim: 4096,
@@ -51,13 +53,11 @@ impl Default for TrOCRConfig {
             decoder_layerdrop: 0.0,
             use_cache: true,
             scale_embedding: false,
-            use_learned_position_embeddings: true,
-            layernorm_embedding: true,
             pad_token_id: 1,
             bos_token_id: 0,
             eos_token_id: 2,
-            num_attention_heads: 12,
             decoder_vocab_size: Some(50265),
+            use_learned_position_embeddings: true,
         }
     }
 }
@@ -221,8 +221,8 @@ impl TrOCRDecoderLayer {
         let encoder_attn = TrOCRAttention::load(
             vb.pp("encoder_attn"),
             cfg,
-            Some(cfg.hidden_size),
-            Some(cfg.hidden_size),
+            Some(cfg.cross_attention_hidden_size),
+            Some(cfg.cross_attention_hidden_size),
         )?;
         let encoder_attn_layer_norm =
             layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
@@ -294,6 +294,9 @@ impl TrOCRDecoder {
         let vb = vb.pp("decoder.model.decoder");
         let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
+        if !cfg.use_learned_position_embeddings {
+            candle::bail!("only models with use_learned_position_embeddings=true are supported")
+        }
         let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
         let mut layers = Vec::with_capacity(cfg.decoder_layers);
         let vb_l = vb.pp("layers");
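
For reference, the `#[serde(default = "...")]` attribute used above makes the field optional in the JSON: when it is absent, serde calls the named function and the flag comes back `true`; an explicit `false` survives deserialization only to trip the `candle::bail!` guard in `TrOCRDecoder::new`. A self-contained sketch of that fallback behavior; the `Demo` struct is hypothetical, mirroring the pattern added to `TrOCRConfig`:

use serde::Deserialize;

fn default_use_learned_position_embeddings() -> bool {
    true
}

// Hypothetical struct reproducing the serde-default pattern from the diff.
#[derive(Debug, Deserialize)]
struct Demo {
    #[serde(default = "default_use_learned_position_embeddings")]
    use_learned_position_embeddings: bool,
}

fn main() -> serde_json::Result<()> {
    // Field absent: serde calls the default fn, so the flag is true.
    let a: Demo = serde_json::from_str("{}")?;
    assert!(a.use_learned_position_embeddings);
    // Field present: the JSON value overrides the default.
    let b: Demo = serde_json::from_str(r#"{"use_learned_position_embeddings": false}"#)?;
    assert!(!b.use_learned_position_embeddings);
    Ok(())
}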

candle-transformers/src/models/vit.rs

@@ -3,7 +3,7 @@ use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
     pub hidden_size: usize,
     pub num_hidden_layers: usize,