Mirror of https://github.com/huggingface/candle.git
Use the repo config for trocr rather than hardcoding it + small tweaks. (#1689)
* Use the repo config for trocr rather than hardcoding it + small tweaks.
* Add support for the printed models.
* Fail with an appropriate error message on missing position embeddings.
@@ -10,15 +10,36 @@ use clap::{Parser, ValueEnum};
 use candle::{DType, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
-use candle_transformers::models::trocr;
+use candle_transformers::models::{trocr, vit};
 
 use tokenizers::Tokenizer;
 mod image_processor;
 
 #[derive(Clone, Debug, Copy, ValueEnum)]
 enum Which {
-    Base,
-    Large,
+    #[value(name = "base")]
+    BaseHandwritten,
+    #[value(name = "large")]
+    LargeHandwritten,
+    BasePrinted,
+    LargePrinted,
 }
 
+impl Which {
+    fn repo_and_branch_name(&self) -> (&str, &str) {
+        match self {
+            Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
+            Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
+            Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
+            Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
+        }
+    }
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+struct Config {
+    encoder: vit::Config,
+    decoder: trocr::TrOCRConfig,
+}
+
 #[derive(Parser, Debug)]
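A note on the enum above (an aside, not part of the diff): clap derives kebab-case value names by default, so the new printed variants are selected on the command line as base-printed / large-printed, while the explicit #[value(name = ...)] attributes keep the old base / large spellings pointing at the handwritten models. A minimal sketch to confirm the derived names, assuming clap 4 with the derive feature:

    use clap::ValueEnum;

    #[derive(Clone, Copy, Debug, ValueEnum)]
    enum Which {
        #[value(name = "base")]
        BaseHandwritten,
        #[value(name = "large")]
        LargeHandwritten,
        BasePrinted,
        LargePrinted,
    }

    fn main() {
        // Prints: base, large, base-printed, large-printed
        for w in Which::value_variants() {
            println!("{}", w.to_possible_value().unwrap().get_name());
        }
    }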
@@ -34,63 +55,64 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Text to be translated
+    /// The image file to be processed.
     #[arg(long)]
     image: String,
+
+    /// Tokenization config.
+    #[arg(long)]
+    tokenizer: Option<String>,
 }
 
 pub fn main() -> anyhow::Result<()> {
-    use hf_hub::api::sync::Api;
     let args = Args::parse();
+    let api = hf_hub::api::sync::Api::new()?;
 
-    let tokenizer_dec = {
-        let tokenizer = Api::new()?
-            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
-            .get("tokenizer.json")?;
-
-        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
+    let mut tokenizer_dec = {
+        let tokenizer_file = match args.tokenizer {
+            None => api
+                .model(String::from("ToluClassics/candle-trocr-tokenizer"))
+                .get("tokenizer.json")?,
+            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
+        };
+        let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
+        TokenOutputStream::new(tokenizer)
     };
-
-    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
-
     let device = candle_examples::device(args.cpu)?;
 
     let vb = {
         let model = match args.model {
             Some(model) => std::path::PathBuf::from(model),
-            None => match args.which {
-                Which::Base => Api::new()?
-                    .repo(hf_hub::Repo::with_revision(
-                        "microsoft/trocr-base-handwritten".to_string(),
-                        hf_hub::RepoType::Model,
-                        "refs/pr/3".to_string(),
-                    ))
-                    .get("model.safetensors")?,
-                Which::Large => Api::new()?
-                    .repo(hf_hub::Repo::with_revision(
-                        "microsoft/trocr-large-handwritten".to_string(),
-                        hf_hub::RepoType::Model,
-                        "refs/pr/6".to_string(),
-                    ))
-                    .get("model.safetensors")?,
-            },
+            None => {
+                let (repo, branch) = args.which.repo_and_branch_name();
+                api.repo(hf_hub::Repo::with_revision(
+                    repo.to_string(),
+                    hf_hub::RepoType::Model,
+                    branch.to_string(),
+                ))
+                .get("model.safetensors")?
+            }
         };
         println!("model: {:?}", model);
        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
     };
 
-    let encoder_config = match args.which {
-        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
-        Which::Large => {
-            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
-        }
+    let (encoder_config, decoder_config) = {
+        let (repo, branch) = args.which.repo_and_branch_name();
+        let config_filename = api
+            .repo(hf_hub::Repo::with_revision(
+                repo.to_string(),
+                hf_hub::RepoType::Model,
+                branch.to_string(),
+            ))
+            .get("config.json")?;
+        let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
+        (config.encoder, config.decoder)
     };
-
-    let decoder_config = trocr::TrOCRConfig::default();
     let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
 
-    let config = image_processor::ProcessorConfig::default();
-    let processor = image_processor::ViTImageProcessor::new(&config);
+    let processor_config = image_processor::ProcessorConfig::default();
+    let processor = image_processor::ViTImageProcessor::new(&processor_config);
 
     let image = vec![args.image.as_str()];
     let image = processor.preprocess(image)?;
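The main() changes above replace the hardcoded encoder/decoder configs with the repo's own config.json, deserialized through the new Config struct. That works because the TrOCR config.json nests the encoder and decoder settings under "encoder" and "decoder" keys. A simplified, self-contained sketch of the same serde pattern (field sets trimmed down, values hypothetical):

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct EncoderConfig {
        hidden_size: usize,
    }

    #[derive(Debug, Deserialize)]
    struct DecoderConfig {
        d_model: usize,
    }

    #[derive(Debug, Deserialize)]
    struct Config {
        encoder: EncoderConfig,
        decoder: DecoderConfig,
    }

    fn main() -> serde_json::Result<()> {
        // Mimics the "encoder"/"decoder" nesting of the repo's config.json.
        let raw = r#"{ "encoder": { "hidden_size": 768 }, "decoder": { "d_model": 1024 } }"#;
        let config: Config = serde_json::from_str(raw)?;
        println!("{config:?}");
        Ok(())
    }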
@@ -3,13 +3,16 @@ use candle::{Result, Tensor};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
 };
-use serde::Deserialize;
 
-#[derive(Debug, Clone, PartialEq, Deserialize)]
+fn default_use_learned_position_embeddings() -> bool {
+    true
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct TrOCRConfig {
     pub vocab_size: usize,
     pub d_model: usize,
     pub hidden_size: usize,
     pub cross_attention_hidden_size: usize,
     pub decoder_layers: usize,
     pub decoder_attention_heads: usize,
     pub decoder_ffn_dim: usize,
@@ -23,13 +26,12 @@ pub struct TrOCRConfig {
     pub decoder_layerdrop: f64,
     pub use_cache: bool,
     pub scale_embedding: bool,
-    pub use_learned_position_embeddings: bool,
     pub layernorm_embedding: bool,
     pub pad_token_id: usize,
     pub bos_token_id: usize,
     pub eos_token_id: u32,
     pub num_attention_heads: usize,
     pub decoder_vocab_size: Option<usize>,
+    #[serde(default = "default_use_learned_position_embeddings")]
+    pub use_learned_position_embeddings: bool,
 }
 
 impl Default for TrOCRConfig {
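The #[serde(default = "default_use_learned_position_embeddings")] attribute keeps config.json files that predate the field deserializing cleanly: when the key is absent, serde calls the named function instead of erroring out. A standalone sketch of the mechanism with a cut-down struct:

    use serde::Deserialize;

    fn default_use_learned_position_embeddings() -> bool {
        true
    }

    #[derive(Debug, Deserialize)]
    struct Cfg {
        d_model: usize,
        #[serde(default = "default_use_learned_position_embeddings")]
        use_learned_position_embeddings: bool,
    }

    fn main() -> serde_json::Result<()> {
        // The key is missing from the JSON, so serde falls back to the default fn.
        let cfg: Cfg = serde_json::from_str(r#"{ "d_model": 1024 }"#)?;
        assert!(cfg.use_learned_position_embeddings);
        Ok(())
    }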
@@ -37,7 +39,7 @@ impl Default for TrOCRConfig {
         Self {
             vocab_size: 50265,
             d_model: 1024,
             hidden_size: 768,
             cross_attention_hidden_size: 768,
             decoder_layers: 12,
             decoder_attention_heads: 16,
             decoder_ffn_dim: 4096,
@@ -51,13 +53,11 @@ impl Default for TrOCRConfig {
             decoder_layerdrop: 0.0,
             use_cache: true,
             scale_embedding: false,
-            use_learned_position_embeddings: true,
             layernorm_embedding: true,
             pad_token_id: 1,
             bos_token_id: 0,
             eos_token_id: 2,
             num_attention_heads: 12,
             decoder_vocab_size: Some(50265),
+            use_learned_position_embeddings: true,
         }
     }
 }
@@ -221,8 +221,8 @@ impl TrOCRDecoderLayer {
         let encoder_attn = TrOCRAttention::load(
             vb.pp("encoder_attn"),
             cfg,
-            Some(cfg.hidden_size),
-            Some(cfg.hidden_size),
+            Some(cfg.cross_attention_hidden_size),
+            Some(cfg.cross_attention_hidden_size),
         )?;
         let encoder_attn_layer_norm =
             layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
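The two-line change above fixes the input width of the cross-attention key/value projections: they consume encoder hidden states, so they must be sized by cfg.cross_attention_hidden_size rather than the decoder's cfg.hidden_size. The old code only worked when the two widths happened to coincide. A toy sketch of the invariant (plain Rust, hypothetical dimensions, not candle code):

    // Toy shapes for a cross-attention projection.
    struct Linear {
        in_dim: usize,
        out_dim: usize,
    }

    fn kv_proj(encoder_hidden: usize, attn_dim: usize) -> Linear {
        // Key/value projections read *encoder* states, so in_dim must be the
        // encoder width (cross_attention_hidden_size), not the decoder width.
        Linear { in_dim: encoder_hidden, out_dim: attn_dim }
    }

    fn main() {
        let proj = kv_proj(768, 1024);
        println!("k/v projection: {} -> {}", proj.in_dim, proj.out_dim);
    }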
@@ -294,6 +294,9 @@ impl TrOCRDecoder {
         let vb = vb.pp("decoder.model.decoder");
 
         let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
+        if !cfg.use_learned_position_embeddings {
+            candle::bail!("only models with use_learned_position_embeddings=true are supported")
+        }
         let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
         let mut layers = Vec::with_capacity(cfg.decoder_layers);
         let vb_l = vb.pp("layers");
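This guard implements the commit message's last point: a config with use_learned_position_embeddings = false would otherwise hit a missing embed_positions tensor later during loading, so construction now fails immediately with a clear message. A minimal sketch of the same fail-fast pattern using anyhow's bail! (candle::bail! behaves analogously, returning Err early):

    use anyhow::{bail, Result};

    struct Cfg {
        use_learned_position_embeddings: bool,
    }

    fn load_decoder(cfg: &Cfg) -> Result<()> {
        if !cfg.use_learned_position_embeddings {
            bail!("only models with use_learned_position_embeddings=true are supported")
        }
        // ... proceed to load the learned positional embedding table ...
        Ok(())
    }

    fn main() {
        let err = load_decoder(&Cfg { use_learned_position_embeddings: false }).unwrap_err();
        eprintln!("{err}");
    }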
@@ -3,7 +3,7 @@ use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
     pub hidden_size: usize,
     pub num_hidden_layers: usize,