Use the repo config for trocr rather than hardcoding it + small tweaks. (#1689)

* Use the repo config for trocr rather than hardcoding it + small tweaks.

* Add support for the printed models.

* Fail with an appropriate error message on missing position embeddings.
This commit is contained in:
Laurent Mazare
2024-02-10 13:15:03 +01:00
committed by GitHub
parent 67589791d2
commit 42ce593ec6
3 changed files with 77 additions and 52 deletions

View File

@ -10,15 +10,36 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;
use candle_transformers::models::{trocr, vit};
use tokenizers::Tokenizer;
mod image_processor;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Large,
#[value(name = "base")]
BaseHandwritten,
#[value(name = "large")]
LargeHandwritten,
BasePrinted,
LargePrinted,
}
impl Which {
fn repo_and_branch_name(&self) -> (&str, &str) {
match self {
Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
}
}
}
#[derive(Debug, Clone, serde::Deserialize)]
struct Config {
encoder: vit::Config,
decoder: trocr::TrOCRConfig,
}
#[derive(Parser, Debug)]
@ -34,63 +55,64 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Text to be translated
/// The image file to be processed.
#[arg(long)]
image: String,
/// Tokenization config.
#[arg(long)]
tokenizer: Option<String>,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let api = hf_hub::api::sync::Api::new()?;
let tokenizer_dec = {
let tokenizer = Api::new()?
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?;
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
let mut tokenizer_dec = {
let tokenizer_file = match args.tokenizer {
None => api
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?,
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
};
let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
TokenOutputStream::new(tokenizer)
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-base-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
))
.get("model.safetensors")?,
Which::Large => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-large-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/6".to_string(),
))
.get("model.safetensors")?,
},
None => {
let (repo, branch) = args.which.repo_and_branch_name();
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
branch.to_string(),
))
.get("model.safetensors")?
}
};
println!("model: {:?}", model);
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
};
let encoder_config = match args.which {
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
Which::Large => {
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
}
let (encoder_config, decoder_config) = {
let (repo, branch) = args.which.repo_and_branch_name();
let config_filename = api
.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
branch.to_string(),
))
.get("config.json")?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
(config.encoder, config.decoder)
};
let decoder_config = trocr::TrOCRConfig::default();
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
let config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&config);
let processor_config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&processor_config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;