Use the repo config for trocr rather than hardcoding it + small tweaks. (#1689)

* Use the repo config for trocr rather than hardcoding it + small tweaks.

* Add support for the printed models.

* Fail with an appropriate error message on missing position embeddings.
Laurent Mazare authored 2024-02-10 13:15:03 +01:00, committed by GitHub
parent 67589791d2
commit 42ce593ec6
3 changed files with 77 additions and 52 deletions
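For context, the mechanism behind these changes is fetching config.json (and the weights) from the model repo at a pinned revision via hf-hub, instead of relying on hardcoded settings. A minimal standalone sketch of that access pattern, reusing the base handwritten repo/branch pair from the diff below and assuming anyhow for error handling, could look like this:

// Standalone sketch (not part of the commit): download config.json and
// model.safetensors from a model repo at a specific revision (branch or PR ref).
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};

fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        "microsoft/trocr-base-handwritten".to_string(),
        RepoType::Model,
        "refs/pr/3".to_string(),
    ));
    // Files are downloaded (or read from the local cache) and their paths returned.
    let config_path = repo.get("config.json")?;
    let weights_path = repo.get("model.safetensors")?;
    println!("config: {config_path:?}\nweights: {weights_path:?}");
    Ok(())
}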

candle-examples/examples/trocr/main.rs

@@ -10,15 +10,36 @@ use clap::{Parser, ValueEnum};
 use candle::{DType, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
-use candle_transformers::models::trocr;
+use candle_transformers::models::{trocr, vit};
 use tokenizers::Tokenizer;
 
 mod image_processor;
 
 #[derive(Clone, Debug, Copy, ValueEnum)]
 enum Which {
-    Base,
-    Large,
+    #[value(name = "base")]
+    BaseHandwritten,
+    #[value(name = "large")]
+    LargeHandwritten,
+    BasePrinted,
+    LargePrinted,
+}
+
+impl Which {
+    fn repo_and_branch_name(&self) -> (&str, &str) {
+        match self {
+            Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
+            Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
+            Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
+            Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
+        }
+    }
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+struct Config {
+    encoder: vit::Config,
+    decoder: trocr::TrOCRConfig,
 }
 
 #[derive(Parser, Debug)]
@@ -34,63 +55,64 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Text to be translated
+    /// The image file to be processed.
     #[arg(long)]
     image: String,
+
+    /// Tokenization config.
+    #[arg(long)]
+    tokenizer: Option<String>,
 }
 
 pub fn main() -> anyhow::Result<()> {
-    use hf_hub::api::sync::Api;
-
     let args = Args::parse();
+    let api = hf_hub::api::sync::Api::new()?;
 
-    let tokenizer_dec = {
-        let tokenizer = Api::new()?
-            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
-            .get("tokenizer.json")?;
-
-        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
+    let mut tokenizer_dec = {
+        let tokenizer_file = match args.tokenizer {
+            None => api
+                .model(String::from("ToluClassics/candle-trocr-tokenizer"))
+                .get("tokenizer.json")?,
+            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
+        };
+        let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
+        TokenOutputStream::new(tokenizer)
     };
-    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
 
     let device = candle_examples::device(args.cpu)?;
 
     let vb = {
         let model = match args.model {
             Some(model) => std::path::PathBuf::from(model),
-            None => match args.which {
-                Which::Base => Api::new()?
-                    .repo(hf_hub::Repo::with_revision(
-                        "microsoft/trocr-base-handwritten".to_string(),
-                        hf_hub::RepoType::Model,
-                        "refs/pr/3".to_string(),
-                    ))
-                    .get("model.safetensors")?,
-                Which::Large => Api::new()?
-                    .repo(hf_hub::Repo::with_revision(
-                        "microsoft/trocr-large-handwritten".to_string(),
-                        hf_hub::RepoType::Model,
-                        "refs/pr/6".to_string(),
-                    ))
-                    .get("model.safetensors")?,
-            },
+            None => {
+                let (repo, branch) = args.which.repo_and_branch_name();
+                api.repo(hf_hub::Repo::with_revision(
+                    repo.to_string(),
+                    hf_hub::RepoType::Model,
+                    branch.to_string(),
+                ))
+                .get("model.safetensors")?
+            }
        };
         println!("model: {:?}", model);
         unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
     };
 
-    let encoder_config = match args.which {
-        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
-        Which::Large => {
-            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
-        }
+    let (encoder_config, decoder_config) = {
+        let (repo, branch) = args.which.repo_and_branch_name();
+        let config_filename = api
+            .repo(hf_hub::Repo::with_revision(
+                repo.to_string(),
+                hf_hub::RepoType::Model,
+                branch.to_string(),
+            ))
+            .get("config.json")?;
+        let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
+        (config.encoder, config.decoder)
     };
 
-    let decoder_config = trocr::TrOCRConfig::default();
-
     let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
 
-    let config = image_processor::ProcessorConfig::default();
-    let processor = image_processor::ViTImageProcessor::new(&config);
+    let processor_config = image_processor::ProcessorConfig::default();
+    let processor = image_processor::ViTImageProcessor::new(&processor_config);
 
     let image = vec![args.image.as_str()];
     let image = processor.preprocess(image)?;
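The (encoder_config, decoder_config) block above relies on serde deserializing the repo's config.json into the new wrapper Config { encoder, decoder }. A rough standalone sketch of that idea, with deliberately trimmed stand-in structs (the real vit::Config and trocr::TrOCRConfig carry many more fields) and made-up example values:

use serde::Deserialize;

// Trimmed stand-ins for vit::Config and trocr::TrOCRConfig; the wrapper and
// deserialization shape are the same as in the example above.
#[derive(Debug, Deserialize)]
struct EncoderConfig {
    hidden_size: usize,
}

fn default_use_learned_position_embeddings() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct DecoderConfig {
    d_model: usize,
    cross_attention_hidden_size: usize,
    // Some configs omit this field; default it to true as the model code does.
    #[serde(default = "default_use_learned_position_embeddings")]
    use_learned_position_embeddings: bool,
}

#[derive(Debug, Deserialize)]
struct Config {
    encoder: EncoderConfig,
    decoder: DecoderConfig,
}

fn main() -> anyhow::Result<()> {
    // Illustrative values only, not a real config.json from the hub.
    let json = r#"{
        "encoder": { "hidden_size": 768 },
        "decoder": { "d_model": 1024, "cross_attention_hidden_size": 768 }
    }"#;
    let config: Config = serde_json::from_str(json)?;
    let (encoder_config, decoder_config) = (config.encoder, config.decoder);
    assert!(decoder_config.use_learned_position_embeddings);
    println!("{encoder_config:?}\n{decoder_config:?}");
    Ok(())
}

The #[serde(default = ...)] attribute mirrors the trocr.rs change below: a config.json that omits use_learned_position_embeddings still deserializes and defaults to learned embeddings.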

candle-transformers/src/models/trocr.rs

@@ -3,13 +3,16 @@ use candle::{Result, Tensor};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
 };
-use serde::Deserialize;
 
-#[derive(Debug, Clone, PartialEq, Deserialize)]
+fn default_use_learned_position_embeddings() -> bool {
+    true
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct TrOCRConfig {
     pub vocab_size: usize,
     pub d_model: usize,
-    pub hidden_size: usize,
+    pub cross_attention_hidden_size: usize,
     pub decoder_layers: usize,
     pub decoder_attention_heads: usize,
     pub decoder_ffn_dim: usize,
@@ -23,13 +26,12 @@ pub struct TrOCRConfig {
     pub decoder_layerdrop: f64,
     pub use_cache: bool,
     pub scale_embedding: bool,
-    pub use_learned_position_embeddings: bool,
-    pub layernorm_embedding: bool,
     pub pad_token_id: usize,
     pub bos_token_id: usize,
     pub eos_token_id: u32,
-    pub num_attention_heads: usize,
     pub decoder_vocab_size: Option<usize>,
+    #[serde(default = "default_use_learned_position_embeddings")]
+    pub use_learned_position_embeddings: bool,
 }
 
 impl Default for TrOCRConfig {
@@ -37,7 +39,7 @@ impl Default for TrOCRConfig {
         Self {
             vocab_size: 50265,
             d_model: 1024,
-            hidden_size: 768,
+            cross_attention_hidden_size: 768,
             decoder_layers: 12,
             decoder_attention_heads: 16,
             decoder_ffn_dim: 4096,
@@ -51,13 +53,11 @@ impl Default for TrOCRConfig {
             decoder_layerdrop: 0.0,
             use_cache: true,
             scale_embedding: false,
-            use_learned_position_embeddings: true,
-            layernorm_embedding: true,
             pad_token_id: 1,
             bos_token_id: 0,
             eos_token_id: 2,
-            num_attention_heads: 12,
             decoder_vocab_size: Some(50265),
+            use_learned_position_embeddings: true,
         }
     }
 }
@@ -221,8 +221,8 @@ impl TrOCRDecoderLayer {
         let encoder_attn = TrOCRAttention::load(
             vb.pp("encoder_attn"),
             cfg,
-            Some(cfg.hidden_size),
-            Some(cfg.hidden_size),
+            Some(cfg.cross_attention_hidden_size),
+            Some(cfg.cross_attention_hidden_size),
         )?;
         let encoder_attn_layer_norm =
             layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
@@ -294,6 +294,9 @@ impl TrOCRDecoder {
         let vb = vb.pp("decoder.model.decoder");
         let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
+        if !cfg.use_learned_position_embeddings {
+            candle::bail!("only models with use_learned_position_embeddings=true are supported")
+        }
         let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
         let mut layers = Vec::with_capacity(cfg.decoder_layers);
         let vb_l = vb.pp("layers");
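The guard added above turns a config that requests non-learned position embeddings into an immediate, descriptive error instead of a failure deeper inside weight loading. A minimal sketch of the same check, using anyhow::bail! in place of candle::bail! so it compiles without the candle crate:

use anyhow::bail;

// Stand-in for the relevant part of TrOCRConfig.
struct TrOCRConfig {
    use_learned_position_embeddings: bool,
}

fn load_decoder(cfg: &TrOCRConfig) -> anyhow::Result<()> {
    // Only learned position embeddings are implemented; reject anything else
    // with a clear message instead of erroring later on a missing tensor.
    if !cfg.use_learned_position_embeddings {
        bail!("only models with use_learned_position_embeddings=true are supported")
    }
    // ... load the learned positional embedding and the decoder layers here ...
    Ok(())
}

fn main() {
    let cfg = TrOCRConfig {
        use_learned_position_embeddings: false,
    };
    // Prints the explanatory error rather than panicking deep inside weight loading.
    if let Err(e) = load_decoder(&cfg) {
        eprintln!("error: {e}");
    }
}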

candle-transformers/src/models/vit.rs

@@ -3,7 +3,7 @@ use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
     pub hidden_size: usize,
     pub num_hidden_layers: usize,