mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Use the repo config for trocr rather than hardcoding it + small tweaks. (#1689)
* Use the repo config for trocr rather than hardcoding it + small tweaks. * Add support for the printed models. * Fail with an appropriate error message on missing position embeddings.
This commit is contained in:
@ -10,15 +10,36 @@ use clap::{Parser, ValueEnum};
|
||||
use candle::{DType, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::trocr;
|
||||
use candle_transformers::models::{trocr, vit};
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
mod image_processor;
|
||||
|
||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
enum Which {
|
||||
Base,
|
||||
Large,
|
||||
#[value(name = "base")]
|
||||
BaseHandwritten,
|
||||
#[value(name = "large")]
|
||||
LargeHandwritten,
|
||||
BasePrinted,
|
||||
LargePrinted,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn repo_and_branch_name(&self) -> (&str, &str) {
|
||||
match self {
|
||||
Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
|
||||
Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
|
||||
Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
|
||||
Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
struct Config {
|
||||
encoder: vit::Config,
|
||||
decoder: trocr::TrOCRConfig,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -34,63 +55,64 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Text to be translated
|
||||
/// The image file to be processed.
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Tokenization config.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
use hf_hub::api::sync::Api;
|
||||
let args = Args::parse();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
|
||||
let tokenizer_dec = {
|
||||
let tokenizer = Api::new()?
|
||||
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
|
||||
.get("tokenizer.json")?;
|
||||
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
let mut tokenizer_dec = {
|
||||
let tokenizer_file = match args.tokenizer {
|
||||
None => api
|
||||
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
|
||||
.get("tokenizer.json")?,
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
|
||||
TokenOutputStream::new(tokenizer)
|
||||
};
|
||||
|
||||
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let vb = {
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => match args.which {
|
||||
Which::Base => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
"microsoft/trocr-base-handwritten".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/3".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
Which::Large => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
"microsoft/trocr-large-handwritten".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/6".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
},
|
||||
None => {
|
||||
let (repo, branch) = args.which.repo_and_branch_name();
|
||||
api.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
branch.to_string(),
|
||||
))
|
||||
.get("model.safetensors")?
|
||||
}
|
||||
};
|
||||
println!("model: {:?}", model);
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
|
||||
};
|
||||
|
||||
let encoder_config = match args.which {
|
||||
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
|
||||
Which::Large => {
|
||||
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
|
||||
}
|
||||
let (encoder_config, decoder_config) = {
|
||||
let (repo, branch) = args.which.repo_and_branch_name();
|
||||
let config_filename = api
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
branch.to_string(),
|
||||
))
|
||||
.get("config.json")?;
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
(config.encoder, config.decoder)
|
||||
};
|
||||
|
||||
let decoder_config = trocr::TrOCRConfig::default();
|
||||
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
|
||||
|
||||
let config = image_processor::ProcessorConfig::default();
|
||||
let processor = image_processor::ViTImageProcessor::new(&config);
|
||||
let processor_config = image_processor::ProcessorConfig::default();
|
||||
let processor = image_processor::ViTImageProcessor::new(&processor_config);
|
||||
|
||||
let image = vec![args.image.as_str()];
|
||||
let image = processor.preprocess(image)?;
|
||||
|
Reference in New Issue
Block a user