mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Use the repo config for trocr rather than hardcoding it + small tweaks. (#1689)
* Use the repo config for trocr rather than hardcoding it + small tweaks. * Add support for the printed models. * Fail with an appropriate error message on missing position embeddings.
This commit is contained in:
@ -10,15 +10,36 @@ use clap::{Parser, ValueEnum};
|
|||||||
use candle::{DType, Tensor};
|
use candle::{DType, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::models::trocr;
|
use candle_transformers::models::{trocr, vit};
|
||||||
|
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
mod image_processor;
|
mod image_processor;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||||
enum Which {
|
enum Which {
|
||||||
Base,
|
#[value(name = "base")]
|
||||||
Large,
|
BaseHandwritten,
|
||||||
|
#[value(name = "large")]
|
||||||
|
LargeHandwritten,
|
||||||
|
BasePrinted,
|
||||||
|
LargePrinted,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
fn repo_and_branch_name(&self) -> (&str, &str) {
|
||||||
|
match self {
|
||||||
|
Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
|
||||||
|
Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
|
||||||
|
Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
|
||||||
|
Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
|
struct Config {
|
||||||
|
encoder: vit::Config,
|
||||||
|
decoder: trocr::TrOCRConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -34,63 +55,64 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
/// Text to be translated
|
/// The image file to be processed.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
image: String,
|
image: String,
|
||||||
|
|
||||||
|
/// Tokenization config.
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub fn main() -> anyhow::Result<()> {
|
||||||
use hf_hub::api::sync::Api;
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
|
||||||
let tokenizer_dec = {
|
let mut tokenizer_dec = {
|
||||||
let tokenizer = Api::new()?
|
let tokenizer_file = match args.tokenizer {
|
||||||
|
None => api
|
||||||
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
|
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
|
||||||
.get("tokenizer.json")?;
|
.get("tokenizer.json")?,
|
||||||
|
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
};
|
||||||
|
let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
|
||||||
|
TokenOutputStream::new(tokenizer)
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let vb = {
|
let vb = {
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
Some(model) => std::path::PathBuf::from(model),
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
None => match args.which {
|
None => {
|
||||||
Which::Base => Api::new()?
|
let (repo, branch) = args.which.repo_and_branch_name();
|
||||||
.repo(hf_hub::Repo::with_revision(
|
api.repo(hf_hub::Repo::with_revision(
|
||||||
"microsoft/trocr-base-handwritten".to_string(),
|
repo.to_string(),
|
||||||
hf_hub::RepoType::Model,
|
hf_hub::RepoType::Model,
|
||||||
"refs/pr/3".to_string(),
|
branch.to_string(),
|
||||||
))
|
))
|
||||||
.get("model.safetensors")?,
|
.get("model.safetensors")?
|
||||||
Which::Large => Api::new()?
|
}
|
||||||
.repo(hf_hub::Repo::with_revision(
|
|
||||||
"microsoft/trocr-large-handwritten".to_string(),
|
|
||||||
hf_hub::RepoType::Model,
|
|
||||||
"refs/pr/6".to_string(),
|
|
||||||
))
|
|
||||||
.get("model.safetensors")?,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
println!("model: {:?}", model);
|
println!("model: {:?}", model);
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
|
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
|
||||||
};
|
};
|
||||||
|
|
||||||
let encoder_config = match args.which {
|
let (encoder_config, decoder_config) = {
|
||||||
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
|
let (repo, branch) = args.which.repo_and_branch_name();
|
||||||
Which::Large => {
|
let config_filename = api
|
||||||
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
|
.repo(hf_hub::Repo::with_revision(
|
||||||
}
|
repo.to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
branch.to_string(),
|
||||||
|
))
|
||||||
|
.get("config.json")?;
|
||||||
|
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||||
|
(config.encoder, config.decoder)
|
||||||
};
|
};
|
||||||
|
|
||||||
let decoder_config = trocr::TrOCRConfig::default();
|
|
||||||
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
|
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
|
||||||
|
|
||||||
let config = image_processor::ProcessorConfig::default();
|
let processor_config = image_processor::ProcessorConfig::default();
|
||||||
let processor = image_processor::ViTImageProcessor::new(&config);
|
let processor = image_processor::ViTImageProcessor::new(&processor_config);
|
||||||
|
|
||||||
let image = vec![args.image.as_str()];
|
let image = vec![args.image.as_str()];
|
||||||
let image = processor.preprocess(image)?;
|
let image = processor.preprocess(image)?;
|
||||||
|
@ -3,13 +3,16 @@ use candle::{Result, Tensor};
|
|||||||
use candle_nn::{
|
use candle_nn::{
|
||||||
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
|
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
fn default_use_learned_position_embeddings() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||||
pub struct TrOCRConfig {
|
pub struct TrOCRConfig {
|
||||||
pub vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
pub d_model: usize,
|
pub d_model: usize,
|
||||||
pub hidden_size: usize,
|
pub cross_attention_hidden_size: usize,
|
||||||
pub decoder_layers: usize,
|
pub decoder_layers: usize,
|
||||||
pub decoder_attention_heads: usize,
|
pub decoder_attention_heads: usize,
|
||||||
pub decoder_ffn_dim: usize,
|
pub decoder_ffn_dim: usize,
|
||||||
@ -23,13 +26,12 @@ pub struct TrOCRConfig {
|
|||||||
pub decoder_layerdrop: f64,
|
pub decoder_layerdrop: f64,
|
||||||
pub use_cache: bool,
|
pub use_cache: bool,
|
||||||
pub scale_embedding: bool,
|
pub scale_embedding: bool,
|
||||||
pub use_learned_position_embeddings: bool,
|
|
||||||
pub layernorm_embedding: bool,
|
|
||||||
pub pad_token_id: usize,
|
pub pad_token_id: usize,
|
||||||
pub bos_token_id: usize,
|
pub bos_token_id: usize,
|
||||||
pub eos_token_id: u32,
|
pub eos_token_id: u32,
|
||||||
pub num_attention_heads: usize,
|
|
||||||
pub decoder_vocab_size: Option<usize>,
|
pub decoder_vocab_size: Option<usize>,
|
||||||
|
#[serde(default = "default_use_learned_position_embeddings")]
|
||||||
|
pub use_learned_position_embeddings: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for TrOCRConfig {
|
impl Default for TrOCRConfig {
|
||||||
@ -37,7 +39,7 @@ impl Default for TrOCRConfig {
|
|||||||
Self {
|
Self {
|
||||||
vocab_size: 50265,
|
vocab_size: 50265,
|
||||||
d_model: 1024,
|
d_model: 1024,
|
||||||
hidden_size: 768,
|
cross_attention_hidden_size: 768,
|
||||||
decoder_layers: 12,
|
decoder_layers: 12,
|
||||||
decoder_attention_heads: 16,
|
decoder_attention_heads: 16,
|
||||||
decoder_ffn_dim: 4096,
|
decoder_ffn_dim: 4096,
|
||||||
@ -51,13 +53,11 @@ impl Default for TrOCRConfig {
|
|||||||
decoder_layerdrop: 0.0,
|
decoder_layerdrop: 0.0,
|
||||||
use_cache: true,
|
use_cache: true,
|
||||||
scale_embedding: false,
|
scale_embedding: false,
|
||||||
use_learned_position_embeddings: true,
|
|
||||||
layernorm_embedding: true,
|
|
||||||
pad_token_id: 1,
|
pad_token_id: 1,
|
||||||
bos_token_id: 0,
|
bos_token_id: 0,
|
||||||
eos_token_id: 2,
|
eos_token_id: 2,
|
||||||
num_attention_heads: 12,
|
|
||||||
decoder_vocab_size: Some(50265),
|
decoder_vocab_size: Some(50265),
|
||||||
|
use_learned_position_embeddings: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -221,8 +221,8 @@ impl TrOCRDecoderLayer {
|
|||||||
let encoder_attn = TrOCRAttention::load(
|
let encoder_attn = TrOCRAttention::load(
|
||||||
vb.pp("encoder_attn"),
|
vb.pp("encoder_attn"),
|
||||||
cfg,
|
cfg,
|
||||||
Some(cfg.hidden_size),
|
Some(cfg.cross_attention_hidden_size),
|
||||||
Some(cfg.hidden_size),
|
Some(cfg.cross_attention_hidden_size),
|
||||||
)?;
|
)?;
|
||||||
let encoder_attn_layer_norm =
|
let encoder_attn_layer_norm =
|
||||||
layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||||
@ -294,6 +294,9 @@ impl TrOCRDecoder {
|
|||||||
let vb = vb.pp("decoder.model.decoder");
|
let vb = vb.pp("decoder.model.decoder");
|
||||||
|
|
||||||
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
|
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
|
||||||
|
if !cfg.use_learned_position_embeddings {
|
||||||
|
candle::bail!("only models with use_learned_position_embeddings=true are supported")
|
||||||
|
}
|
||||||
let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
|
let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
|
||||||
let mut layers = Vec::with_capacity(cfg.decoder_layers);
|
let mut layers = Vec::with_capacity(cfg.decoder_layers);
|
||||||
let vb_l = vb.pp("layers");
|
let vb_l = vb.pp("layers");
|
||||||
|
@ -3,7 +3,7 @@ use candle::{IndexOp, Module, Result, Tensor, D};
|
|||||||
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
|
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub num_hidden_layers: usize,
|
pub num_hidden_layers: usize,
|
||||||
|
Reference in New Issue
Block a user