Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Use the repo config for trocr rather than hardcoding it + small tweaks. (#1689)

* Use the repo config for trocr rather than hardcoding it + small tweaks.
* Add support for the printed models.
* Fail with an appropriate error message on missing position embeddings.
@@ -3,13 +3,16 @@ use candle::{Result, Tensor};
use candle_nn::{
    embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
};
use serde::Deserialize;

#[derive(Debug, Clone, PartialEq, Deserialize)]
fn default_use_learned_position_embeddings() -> bool {
    true
}

#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct TrOCRConfig {
    pub vocab_size: usize,
    pub d_model: usize,
    pub hidden_size: usize,
    pub cross_attention_hidden_size: usize,
    pub decoder_layers: usize,
    pub decoder_attention_heads: usize,
    pub decoder_ffn_dim: usize,
@@ -23,13 +26,12 @@ pub struct TrOCRConfig {
    pub decoder_layerdrop: f64,
    pub use_cache: bool,
    pub scale_embedding: bool,
    pub use_learned_position_embeddings: bool,
    pub layernorm_embedding: bool,
    pub pad_token_id: usize,
    pub bos_token_id: usize,
    pub eos_token_id: u32,
    pub num_attention_heads: usize,
    pub decoder_vocab_size: Option<usize>,
    #[serde(default = "default_use_learned_position_embeddings")]
    pub use_learned_position_embeddings: bool,
}

impl Default for TrOCRConfig {
@@ -37,7 +39,7 @@ impl Default for TrOCRConfig {
        Self {
            vocab_size: 50265,
            d_model: 1024,
            hidden_size: 768,
            cross_attention_hidden_size: 768,
            decoder_layers: 12,
            decoder_attention_heads: 16,
            decoder_ffn_dim: 4096,
@@ -51,13 +53,11 @@ impl Default for TrOCRConfig {
            decoder_layerdrop: 0.0,
            use_cache: true,
            scale_embedding: false,
            use_learned_position_embeddings: true,
            layernorm_embedding: true,
            pad_token_id: 1,
            bos_token_id: 0,
            eos_token_id: 2,
            num_attention_heads: 12,
            decoder_vocab_size: Some(50265),
            use_learned_position_embeddings: true,
        }
    }
}
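With Deserialize derived and a default function for use_learned_position_embeddings, the decoder configuration can be read from the repository's config.json instead of relying only on the hardcoded Default impl; configs that predate the flag simply fall back to true. Below is a minimal sketch of that serde pattern, using a trimmed-down stand-in struct rather than the full TrOCRConfig (serde with the derive feature and serde_json are assumed as dependencies). In practice the config.json would typically be fetched from the model repo, e.g. via the hf-hub crate, before being parsed.

```rust
// Sketch of the serde default pattern; a trimmed-down stand-in, not the real TrOCRConfig.
use serde::Deserialize;

fn default_use_learned_position_embeddings() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct DecoderConfig {
    vocab_size: usize,
    d_model: usize,
    // Missing in older config files; serde calls the default fn instead of erroring.
    #[serde(default = "default_use_learned_position_embeddings")]
    use_learned_position_embeddings: bool,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{ "vocab_size": 50265, "d_model": 1024 }"#;
    let cfg: DecoderConfig = serde_json::from_str(raw)?;
    assert!(cfg.use_learned_position_embeddings);
    println!("{cfg:?}");
    Ok(())
}
```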
@@ -221,8 +221,8 @@ impl TrOCRDecoderLayer {
        let encoder_attn = TrOCRAttention::load(
            vb.pp("encoder_attn"),
            cfg,
            Some(cfg.hidden_size),
            Some(cfg.hidden_size),
            Some(cfg.cross_attention_hidden_size),
            Some(cfg.cross_attention_hidden_size),
        )?;
        let encoder_attn_layer_norm =
            layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
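The cross-attention layer has to bridge two widths: queries come from the decoder stream while keys and values come from the ViT encoder, whose hidden size is cfg.cross_attention_hidden_size. The sketch below shows that shape situation only; it is not the TrOCRAttention implementation, the dimensions are illustrative, and candle (candle-core) plus candle-nn are assumed as dependencies under the crate names used in the candle workspace.

```rust
// Shape-level sketch of cross-attention projections with mismatched widths.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    let d_model = 1024; // decoder width (queries)
    let encoder_width = 768; // e.g. cross_attention_hidden_size from the ViT encoder

    // Keys/values are projected from encoder states, so their input dim is the
    // encoder width; queries are projected from decoder states.
    let q_proj: Linear = linear(d_model, d_model, vb.pp("q_proj"))?;
    let k_proj: Linear = linear(encoder_width, d_model, vb.pp("k_proj"))?;
    let v_proj: Linear = linear(encoder_width, d_model, vb.pp("v_proj"))?;

    let decoder_states = Tensor::zeros((1, 10, d_model), DType::F32, &dev)?;
    let encoder_states = Tensor::zeros((1, 197, encoder_width), DType::F32, &dev)?;

    let q = q_proj.forward(&decoder_states)?;
    let k = k_proj.forward(&encoder_states)?;
    let v = v_proj.forward(&encoder_states)?;
    println!("q {:?} k {:?} v {:?}", q.dims(), k.dims(), v.dims());
    Ok(())
}
```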
@@ -294,6 +294,9 @@ impl TrOCRDecoder {
        let vb = vb.pp("decoder.model.decoder");

        let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
        if !cfg.use_learned_position_embeddings {
            candle::bail!("only models with use_learned_position_embeddings=true are supported")
        }
        let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
        let mut layers = Vec::with_capacity(cfg.decoder_layers);
        let vb_l = vb.pp("layers");
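The guard above is what turns an unsupported configuration into a readable error rather than a panic deep inside weight loading. A minimal sketch of the same pattern follows; the struct, field, and function names are illustrative, and the candle crate is assumed to be imported under the name candle as in the candle workspace.

```rust
// Sketch: fail fast with a descriptive candle error for an unsupported config.
use candle::Result;

struct DecoderConfig {
    use_learned_position_embeddings: bool,
}

fn check_supported(cfg: &DecoderConfig) -> Result<()> {
    if !cfg.use_learned_position_embeddings {
        // candle::bail! builds a candle::Error and returns early from the function.
        candle::bail!("only models with use_learned_position_embeddings=true are supported")
    }
    Ok(())
}

fn main() {
    let cfg = DecoderConfig { use_learned_position_embeddings: false };
    match check_supported(&cfg) {
        Ok(()) => println!("config supported"),
        Err(e) => eprintln!("unsupported config: {e}"),
    }
}
```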
@@ -3,7 +3,7 @@ use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};

// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
#[derive(Debug, Clone)]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
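Making the ViT encoder Config deserializable as well means both halves of a TrOCR checkpoint can be driven by the nested "encoder" and "decoder" objects that a VisionEncoderDecoder config.json carries. A minimal sketch of parsing such a nested file, using trimmed-down stand-in structs rather than the real candle types (serde and serde_json assumed as dependencies):

```rust
// Sketch: parsing the nested encoder/decoder objects of a TrOCR-style config.json.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct EncoderConfig {
    hidden_size: usize,
    num_hidden_layers: usize,
}

#[derive(Debug, Deserialize)]
struct DecoderConfig {
    vocab_size: usize,
    d_model: usize,
}

#[derive(Debug, Deserialize)]
struct RepoConfig {
    encoder: EncoderConfig,
    decoder: DecoderConfig,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{
        "encoder": { "hidden_size": 768, "num_hidden_layers": 12 },
        "decoder": { "vocab_size": 50265, "d_model": 1024 }
    }"#;
    let cfg: RepoConfig = serde_json::from_str(raw)?;
    println!("{cfg:?}");
    Ok(())
}
```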