mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add support for TrOCR Model (#1303)
* add bce with logit loss * add bce with logit loss * remove imports * fix tiny bug * add test documentation and refactor function * fix test cases and formatting * add trocr model * fix formatting * commit the actual model lol * more formatting * remove tokenizer config
This commit is contained in:
BIN
candle-examples/examples/trocr/assets/trocr.png
Normal file
BIN
candle-examples/examples/trocr/assets/trocr.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 36 KiB |
154
candle-examples/examples/trocr/image_processor.rs
Normal file
154
candle-examples/examples/trocr/image_processor.rs
Normal file
@ -0,0 +1,154 @@
|
||||
use image::{DynamicImage, ImageBuffer};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct ProcessorConfig {
|
||||
do_resize: bool,
|
||||
height: u32,
|
||||
width: u32,
|
||||
do_rescale: bool,
|
||||
do_normalize: bool,
|
||||
image_mean: Vec<f32>,
|
||||
image_std: Vec<f32>,
|
||||
}
|
||||
|
||||
impl Default for ProcessorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
do_resize: true,
|
||||
height: 384,
|
||||
width: 384,
|
||||
do_rescale: true,
|
||||
do_normalize: true,
|
||||
image_mean: vec![0.5, 0.5, 0.5],
|
||||
image_std: vec![0.5, 0.5, 0.5],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ViTImageProcessor {
|
||||
do_resize: bool,
|
||||
height: u32,
|
||||
width: u32,
|
||||
do_normalize: bool,
|
||||
image_mean: Vec<f32>,
|
||||
image_std: Vec<f32>,
|
||||
}
|
||||
|
||||
impl ViTImageProcessor {
|
||||
pub fn new(config: &ProcessorConfig) -> Self {
|
||||
Self {
|
||||
do_resize: config.do_resize,
|
||||
height: config.height,
|
||||
width: config.width,
|
||||
do_normalize: config.do_normalize,
|
||||
image_mean: config.image_mean.clone(),
|
||||
image_std: config.image_std.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
|
||||
let height = self.height as usize;
|
||||
let width = self.width as usize;
|
||||
let channels = 3;
|
||||
|
||||
let images = self.load_images(images)?;
|
||||
|
||||
let resized_images: Vec<DynamicImage> = if self.do_resize {
|
||||
images
|
||||
.iter()
|
||||
.map(|image| self.resize(image.clone(), None).unwrap())
|
||||
.collect()
|
||||
} else {
|
||||
images
|
||||
};
|
||||
|
||||
let normalized_images: Vec<Tensor> = if self.do_normalize {
|
||||
resized_images
|
||||
.iter()
|
||||
.map(|image| self.normalize(image.clone(), None, None).unwrap())
|
||||
.collect()
|
||||
} else {
|
||||
let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
|
||||
resized_images.iter().map(|image| image.to_rgb8()).collect();
|
||||
let data = resized_images
|
||||
.into_iter()
|
||||
.map(|image| image.into_raw())
|
||||
.collect::<Vec<Vec<u8>>>();
|
||||
|
||||
data.iter()
|
||||
.map(|image| {
|
||||
Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
|
||||
.unwrap()
|
||||
.permute((2, 0, 1))
|
||||
.unwrap()
|
||||
})
|
||||
.collect::<Vec<Tensor>>()
|
||||
};
|
||||
|
||||
Tensor::stack(&normalized_images, 0)
|
||||
}
|
||||
|
||||
fn resize(
|
||||
&self,
|
||||
image: image::DynamicImage,
|
||||
size: Option<HashMap<String, u32>>,
|
||||
) -> Result<image::DynamicImage> {
|
||||
let (height, width) = match &size {
|
||||
Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
|
||||
None => (&self.height, &self.width),
|
||||
};
|
||||
|
||||
let resized_image =
|
||||
image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);
|
||||
|
||||
Ok(resized_image)
|
||||
}
|
||||
|
||||
fn normalize(
|
||||
&self,
|
||||
image: image::DynamicImage,
|
||||
mean: Option<Vec<f32>>,
|
||||
std: Option<Vec<f32>>,
|
||||
) -> Result<Tensor> {
|
||||
let mean = match mean {
|
||||
Some(mean) => mean,
|
||||
None => self.image_mean.clone(),
|
||||
};
|
||||
|
||||
let std = match std {
|
||||
Some(std) => std,
|
||||
None => self.image_std.clone(),
|
||||
};
|
||||
|
||||
let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
|
||||
let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;
|
||||
|
||||
let image = image.to_rgb8();
|
||||
let data = image.into_raw();
|
||||
|
||||
let height = self.height as usize;
|
||||
let width = self.width as usize;
|
||||
let channels = 3;
|
||||
|
||||
let data =
|
||||
Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
|
||||
(data.to_dtype(DType::F32)? / 255.)?
|
||||
.broadcast_sub(&mean)?
|
||||
.broadcast_div(&std)
|
||||
}
|
||||
|
||||
pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
|
||||
let mut images: Vec<image::DynamicImage> = Vec::new();
|
||||
for path in image_path {
|
||||
let img = image::io::Reader::open(path)?.decode().unwrap();
|
||||
images.push(img);
|
||||
}
|
||||
|
||||
Ok(images)
|
||||
}
|
||||
}
|
132
candle-examples/examples/trocr/main.rs
Normal file
132
candle-examples/examples/trocr/main.rs
Normal file
@ -0,0 +1,132 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::trocr;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
mod image_processor;
|
||||
|
||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
enum Which {
|
||||
Base,
|
||||
Large,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// Choose the variant of the model to run.
|
||||
#[arg(long, default_value = "base")]
|
||||
which: Which,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Text to be translated
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
use hf_hub::api::sync::Api;
|
||||
let args = Args::parse();
|
||||
|
||||
let tokenizer_dec = {
|
||||
let tokenizer = Api::new()?
|
||||
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
|
||||
.get("tokenizer.json")?;
|
||||
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
};
|
||||
|
||||
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let vb = {
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => match args.which {
|
||||
Which::Base => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
"microsoft/trocr-base-handwritten".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/3".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
Which::Large => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
"microsoft/trocr-large-handwritten".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/6".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
},
|
||||
};
|
||||
println!("model: {:?}", model);
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
|
||||
};
|
||||
|
||||
let encoder_config = match args.which {
|
||||
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
|
||||
Which::Large => {
|
||||
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
|
||||
}
|
||||
};
|
||||
|
||||
let decoder_config = trocr::TrOCRConfig::default();
|
||||
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
|
||||
|
||||
let config = image_processor::ProcessorConfig::default();
|
||||
let processor = image_processor::ViTImageProcessor::new(&config);
|
||||
|
||||
let image = vec![args.image.as_str()];
|
||||
let image = processor.preprocess(image)?;
|
||||
|
||||
let encoder_xs = model.encoder().forward(&image)?;
|
||||
|
||||
let mut logits_processor =
|
||||
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
|
||||
|
||||
let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
|
||||
for index in 0..1000 {
|
||||
let context_size = if index >= 1 { 1 } else { token_ids.len() };
|
||||
let start_pos = token_ids.len().saturating_sub(context_size);
|
||||
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
|
||||
|
||||
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
|
||||
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
token_ids.push(token);
|
||||
|
||||
if let Some(t) = tokenizer_dec.next_token(token)? {
|
||||
use std::io::Write;
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
if token == decoder_config.eos_token_id {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
println!();
|
||||
|
||||
Ok(())
|
||||
}
|
16
candle-examples/examples/trocr/readme.md
Normal file
16
candle-examples/examples/trocr/readme.md
Normal file
@ -0,0 +1,16 @@
|
||||
# candle-trocr
|
||||
|
||||
`TrOCR` is a transformer OCR Model. In this example it is used to
|
||||
transcribe image text. See the associated [model
|
||||
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
|
||||
the model itself.
|
||||
|
||||
## Running an example
|
||||
|
||||
```bash
|
||||
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
|
||||
```
|
||||
|
||||
```
|
||||
<s> industry , Mr. Brown commented icily . " Let us have a</s>
|
||||
```
|
Reference in New Issue
Block a user