diff --git a/candle-examples/examples/trocr/assets/trocr.png b/candle-examples/examples/trocr/assets/trocr.png
new file mode 100644
index 00000000..06886aab
Binary files /dev/null and b/candle-examples/examples/trocr/assets/trocr.png differ
diff --git a/candle-examples/examples/trocr/image_processor.rs b/candle-examples/examples/trocr/image_processor.rs
new file mode 100644
index 00000000..531caa56
--- /dev/null
+++ b/candle-examples/examples/trocr/image_processor.rs
@@ -0,0 +1,154 @@
+use image::{DynamicImage, ImageBuffer};
+use serde::Deserialize;
+use std::collections::HashMap;
+
+use candle::{DType, Device, Result, Tensor};
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct ProcessorConfig {
+    do_resize: bool,
+    height: u32,
+    width: u32,
+    do_rescale: bool,
+    do_normalize: bool,
+    image_mean: Vec<f32>,
+    image_std: Vec<f32>,
+}
+
+impl Default for ProcessorConfig {
+    fn default() -> Self {
+        Self {
+            do_resize: true,
+            height: 384,
+            width: 384,
+            do_rescale: true,
+            do_normalize: true,
+            image_mean: vec![0.5, 0.5, 0.5],
+            image_std: vec![0.5, 0.5, 0.5],
+        }
+    }
+}
+
+pub struct ViTImageProcessor {
+    do_resize: bool,
+    height: u32,
+    width: u32,
+    do_normalize: bool,
+    image_mean: Vec<f32>,
+    image_std: Vec<f32>,
+}
+
+impl ViTImageProcessor {
+    pub fn new(config: &ProcessorConfig) -> Self {
+        Self {
+            do_resize: config.do_resize,
+            height: config.height,
+            width: config.width,
+            do_normalize: config.do_normalize,
+            image_mean: config.image_mean.clone(),
+            image_std: config.image_std.clone(),
+        }
+    }
+
+    pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
+        let height = self.height as usize;
+        let width = self.width as usize;
+        let channels = 3;
+
+        let images = self.load_images(images)?;
+
+        let resized_images: Vec<DynamicImage> = if self.do_resize {
+            images
+                .iter()
+                .map(|image| self.resize(image.clone(), None).unwrap())
+                .collect()
+        } else {
+            images
+        };
+
+        let normalized_images: Vec<Tensor> = if self.do_normalize {
+            resized_images
+                .iter()
+                .map(|image| self.normalize(image.clone(), None, None).unwrap())
+                .collect()
+        } else {
+            let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
+                resized_images.iter().map(|image| image.to_rgb8()).collect();
+            let data = resized_images
+                .into_iter()
+                .map(|image| image.into_raw())
+                .collect::<Vec<Vec<u8>>>();
+
+            data.iter()
+                .map(|image| {
+                    Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
+                        .unwrap()
+                        .permute((2, 0, 1))
+                        .unwrap()
+                })
+                .collect::<Vec<Tensor>>()
+        };
+
+        Tensor::stack(&normalized_images, 0)
+    }
+
+    fn resize(
+        &self,
+        image: image::DynamicImage,
+        size: Option<HashMap<String, u32>>,
+    ) -> Result<DynamicImage> {
+        let (height, width) = match &size {
+            Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
+            None => (&self.height, &self.width),
+        };
+
+        let resized_image =
+            image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);
+
+        Ok(resized_image)
+    }
+
+    fn normalize(
+        &self,
+        image: image::DynamicImage,
+        mean: Option<Vec<f32>>,
+        std: Option<Vec<f32>>,
+    ) -> Result<Tensor> {
+        let mean = match mean {
+            Some(mean) => mean,
+            None => self.image_mean.clone(),
+        };
+
+        let std = match std {
+            Some(std) => std,
+            None => self.image_std.clone(),
+        };
+
+        let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
+        let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;
+
+        let image = image.to_rgb8();
+        let data = image.into_raw();
+
+        let height = self.height as usize;
+        let width = self.width as usize;
+        let channels = 3;
+
+        let data =
+            Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;
+
+        (data.to_dtype(DType::F32)? / 255.)?
+            .broadcast_sub(&mean)?
+            .broadcast_div(&std)
+    }
+
+    pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
+        let mut images: Vec<image::DynamicImage> = Vec::new();
+        for path in image_path {
+            let img = image::io::Reader::open(path)?.decode().unwrap();
+            images.push(img);
+        }
+
+        Ok(images)
+    }
+}
diff --git a/candle-examples/examples/trocr/main.rs b/candle-examples/examples/trocr/main.rs
new file mode 100644
index 00000000..e93d6b2f
--- /dev/null
+++ b/candle-examples/examples/trocr/main.rs
@@ -0,0 +1,132 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Error as E;
+use clap::{Parser, ValueEnum};
+
+use candle::{DType, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::models::trocr;
+
+use tokenizers::Tokenizer;
+mod image_processor;
+
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+    Base,
+    Large,
+}
+
+#[derive(Parser, Debug)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    /// Choose the variant of the model to run.
+    #[arg(long, default_value = "base")]
+    which: Which,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// The image file to be processed.
+    #[arg(long)]
+    image: String,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    use hf_hub::api::sync::Api;
+    let args = Args::parse();
+
+    let tokenizer_dec = {
+        let tokenizer = Api::new()?
+            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
+            .get("tokenizer.json")?;
+
+        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
+    };
+
+    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let vb = {
+        let model = match args.model {
+            Some(model) => std::path::PathBuf::from(model),
+            None => match args.which {
+                Which::Base => Api::new()?
+                    .repo(hf_hub::Repo::with_revision(
+                        "microsoft/trocr-base-handwritten".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/3".to_string(),
+                    ))
+                    .get("model.safetensors")?,
+                Which::Large => Api::new()?
+                    .repo(hf_hub::Repo::with_revision(
+                        "microsoft/trocr-large-handwritten".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/6".to_string(),
+                    ))
+                    .get("model.safetensors")?,
+            },
+        };
+        println!("model: {:?}", model);
+        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
+    };
+
+    let encoder_config = match args.which {
+        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
+        Which::Large => {
+            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
+        }
+    };
+
+    let decoder_config = trocr::TrOCRConfig::default();
+    let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
+
+    let config = image_processor::ProcessorConfig::default();
+    let processor = image_processor::ViTImageProcessor::new(&config);
+
+    let image = vec![args.image.as_str()];
+    let image = processor.preprocess(image)?;
+
+    let encoder_xs = model.encoder().forward(&image)?;
+
+    let mut logits_processor =
+        candle_transformers::generation::LogitsProcessor::new(1337, None, None);
+
+    let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
+    for index in 0..1000 {
+        let context_size = if index >= 1 { 1 } else { token_ids.len() };
+        let start_pos = token_ids.len().saturating_sub(context_size);
+        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
+
+        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
+
+        let logits = logits.squeeze(0)?;
+        let logits = logits.get(logits.dim(0)? - 1)?;
+        let token = logits_processor.sample(&logits)?;
+        token_ids.push(token);
+
+        if let Some(t) = tokenizer_dec.next_token(token)? {
+            use std::io::Write;
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+        if token == decoder_config.eos_token_id {
+            break;
+        }
+    }
+
+    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+    println!();
+
+    Ok(())
+}
diff --git a/candle-examples/examples/trocr/readme.md b/candle-examples/examples/trocr/readme.md
new file mode 100644
index 00000000..329940f8
--- /dev/null
+++ b/candle-examples/examples/trocr/readme.md
@@ -0,0 +1,25 @@
+# candle-trocr
+
+`TrOCR` is a transformer-based OCR model. In this example it is used to
+transcribe the text in an image. See the associated [model
+card](https://huggingface.co/microsoft/trocr-base-printed) for details on
+the model itself.
+
+## Running an example
+
+```bash
+cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
+```
+
+```
+ industry , Mr. Brown commented icily . " Let us have a
+```
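+
+The `--which large` flag switches to the larger
+`microsoft/trocr-large-handwritten` checkpoint, and leaving out `--cpu`
+should run the model on a GPU when the example is built with GPU support.
+A local `model.safetensors` file can also be supplied through `--model`:
+
+```bash
+cargo run --example trocr --release -- --which large --image assets/trocr.png
+```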
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 370b9108..3c025660 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -29,6 +29,7 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod t5;
+pub mod trocr;
 pub mod vgg;
 pub mod vit;
 pub mod whisper;
diff --git a/candle-transformers/src/models/trocr.rs b/candle-transformers/src/models/trocr.rs
new file mode 100644
index 00000000..785b06ca
--- /dev/null
+++ b/candle-transformers/src/models/trocr.rs
@@ -0,0 +1,434 @@
+use crate::models::vit::{Config, Embeddings, Encoder};
+use candle::{Result, Tensor};
+use candle_nn::{
+    embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
+};
+use serde::Deserialize;
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct TrOCRConfig {
+    pub vocab_size: usize,
+    pub d_model: usize,
+    pub hidden_size: usize,
+    pub decoder_layers: usize,
+    pub decoder_attention_heads: usize,
+    pub decoder_ffn_dim: usize,
+    pub activation_function: candle_nn::Activation,
+    pub max_position_embeddings: usize,
+    pub dropout: f64,
+    pub attention_dropout: f64,
+    pub activation_dropout: f64,
+    pub decoder_start_token_id: u32,
+    pub init_std: f64,
+    pub decoder_layerdrop: f64,
+    pub use_cache: bool,
+    pub scale_embedding: bool,
+    pub use_learned_position_embeddings: bool,
+    pub layernorm_embedding: bool,
+    pub pad_token_id: usize,
+    pub bos_token_id: usize,
+    pub eos_token_id: u32,
+    pub num_attention_heads: usize,
+    pub decoder_vocab_size: Option<usize>,
+}
+
+impl Default for TrOCRConfig {
+    fn default() -> Self {
+        Self {
+            vocab_size: 50265,
+            d_model: 1024,
+            hidden_size: 768,
+            decoder_layers: 12,
+            decoder_attention_heads: 16,
+            decoder_ffn_dim: 4096,
+            activation_function: candle_nn::Activation::Gelu,
+            max_position_embeddings: 512,
+            dropout: 0.1,
+            attention_dropout: 0.0,
+            activation_dropout: 0.0,
+            decoder_start_token_id: 2,
+            init_std: 0.02,
+            decoder_layerdrop: 0.0,
+            use_cache: true,
+            scale_embedding: false,
+            use_learned_position_embeddings: true,
+            layernorm_embedding: true,
+            pad_token_id: 1,
+            bos_token_id: 0,
+            eos_token_id: 2,
+            num_attention_heads: 12,
+            decoder_vocab_size: Some(50265),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TrOCRLearnedPositionalEmbedding {
+    offset: usize,
+    weights: Embedding,
+}
+
+impl TrOCRLearnedPositionalEmbedding {
+    fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
+        let offset: usize = 2;
+        let num_embeddings = cfg.max_position_embeddings;
+        let embedding_dim = cfg.d_model;
+        let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;
+
+        Ok(Self { offset, weights })
+    }
+
+    fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
+        let (b_sz, seq_len) = input_ids.dims2()?;
+
+        let mut positions = Tensor::arange(
+            past_key_values_length,
+            seq_len as u32 + past_key_values_length,
+            input_ids.device(),
+        )?
+        .expand((b_sz, seq_len))?;
+
+        positions =
+            positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
+        self.weights.forward(&positions)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TrOCRAttention {
+    head_dim: usize,
+    num_heads: usize,
+    is_decoder: bool,
+    scaling: f64,
+    k_proj: Linear,
+    v_proj: Linear,
+    q_proj: Linear,
+    out_proj: Linear,
+    kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl TrOCRAttention {
+    fn load(
+        vb: VarBuilder,
+        cfg: &TrOCRConfig,
+        kdim: Option<usize>,
+        vdim: Option<usize>,
+    ) -> Result<Self> {
+        let embed_dim = cfg.d_model;
+        let num_heads = cfg.decoder_attention_heads;
+        let head_dim = embed_dim / num_heads;
+        let kdim = kdim.unwrap_or(embed_dim);
+        let vdim = vdim.unwrap_or(embed_dim);
+
+        let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
+        let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;
+
+        let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
+        Ok(Self {
+            head_dim,
+            num_heads,
+            is_decoder: true,
+            scaling: 1. / (head_dim as f64).sqrt(),
+            k_proj,
+            v_proj,
+            q_proj,
+            out_proj,
+            kv_cache: None,
+        })
+    }
+
+    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
+        tensor
+            .reshape((bsz, (), self.num_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        kv_states: Option<&Tensor>,
+        attn_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let (b_sz, tgt_len, _) = xs.dims3()?;
+        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
+        let (key_states, value_states) = match kv_states {
+            None => {
+                let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
+                let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
+                if self.is_decoder {
+                    let kv_states = match &self.kv_cache {
+                        None => (key_states, value_states),
+                        Some((p_key_states, p_value_states)) => {
+                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
+                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
+                            (key_states, value_states)
+                        }
+                    };
+                    self.kv_cache = Some(kv_states.clone());
+                    kv_states
+                } else {
+                    (key_states, value_states)
+                }
+            }
+            Some(kv_states) => {
+                let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
+                let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
+                (key_states, value_states)
+            }
+        };
+        let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
+        let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
+        let key_states = key_states.reshape(proj_shape)?;
+        let value_states = value_states.reshape(proj_shape)?;
+        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
+        let attn_weights = match attn_mask {
+            None => attn_weights,
+            Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
+        };
+        let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_output = attn_probs.matmul(&value_states)?;
+        attn_output
+            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
+            .transpose(1, 2)?
+            .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
+            .apply(&self.out_proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TrOCRDecoderLayer {
+    self_attn: TrOCRAttention,
+    activation_fn: candle_nn::Activation,
+    self_attn_layer_norm: LayerNorm,
+    encoder_attn: TrOCRAttention,
+    encoder_attn_layer_norm: LayerNorm,
+    fc1: Linear,
+    fc2: Linear,
+    final_layer_norm: LayerNorm,
+}
+
+impl TrOCRDecoderLayer {
+    fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
+        let embed_dim = cfg.d_model;
+        let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
+        let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
+        let encoder_attn = TrOCRAttention::load(
+            vb.pp("encoder_attn"),
+            cfg,
+            Some(cfg.hidden_size),
+            Some(cfg.hidden_size),
+        )?;
+        let encoder_attn_layer_norm =
+            layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
+        let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
+        let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
+        let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
+        let activation_fn = candle_nn::Activation::Gelu;
+
+        Ok(Self {
+            self_attn,
+            activation_fn,
+            self_attn_layer_norm,
+            encoder_attn,
+            encoder_attn_layer_norm,
+            fc1,
+            fc2,
+            final_layer_norm,
+        })
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        attention_mask: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let residual = xs.clone();
+        let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
+        let xs = (xs + residual)?;
+        let mut xs = self.self_attn_layer_norm.forward(&xs)?;
+
+        if let Some(encoder_hidden_states) = &encoder_hidden_states {
+            let residual = xs.clone();
+            let encoder_attention_mask = attention_mask.clone(); // TODO
+            xs = self.encoder_attn.forward(
+                &xs,
+                Some(encoder_hidden_states),
+                Some(&encoder_attention_mask),
+            )?;
+            xs = (xs + residual)?;
+            xs = self.encoder_attn_layer_norm.forward(&xs)?
+        }
+
+        let residual = xs.clone();
+        let xs = self.fc1.forward(&xs)?;
+        let xs = self.activation_fn.forward(&xs)?;
+        let xs = self.fc2.forward(&xs)?;
+        let xs = (xs + residual)?;
+        let xs = self.final_layer_norm.forward(&xs)?;
+
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TrOCRDecoder {
+    layers: Vec<TrOCRDecoderLayer>,
+    embed_scale: Option<f64>,
+    embed_tokens: Embedding,
+    embed_positions: TrOCRLearnedPositionalEmbedding,
+}
+
+impl TrOCRDecoder {
+    fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
+        let vb = vb.pp("decoder.model.decoder");
+
+        let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
+        let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
+        let mut layers = Vec::with_capacity(cfg.decoder_layers);
+        let vb_l = vb.pp("layers");
+        for idx in 0..cfg.decoder_layers {
+            let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
+            layers.push(layer)
+        }
+        let embed_scale = if cfg.scale_embedding {
+            Some((cfg.d_model as f64).sqrt())
+        } else {
+            None
+        };
+
+        Ok(Self {
+            layers,
+            embed_scale,
+            embed_tokens,
+            embed_positions,
+        })
+    }
+
+    pub fn forward(
+        &mut self,
+        xs: &Tensor,
+        encoder_xs: Option<&Tensor>,
+        past_kv_len: usize,
+        attn_mask: &Tensor,
+    ) -> Result<Tensor> {
+        let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
+        let xs = xs.apply(&self.embed_tokens)?;
+
+        let xs = match self.embed_scale {
+            None => xs,
+            Some(scale) => (xs * scale)?,
+        };
+
+        let mut xs = xs.broadcast_add(&embed_pos)?;
+
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, encoder_xs)?;
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TrOCREncoder {
+    embeddings: Embeddings,
+    encoder: Encoder,
+    layernorm: LayerNorm,
+}
+
+impl TrOCREncoder {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_v = vb.pp("encoder");
+
+        let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
+
+        let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
+        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
+
+        Ok(Self {
+            embeddings,
+            encoder,
+            layernorm,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let embedding_output = self.embeddings.forward(xs, None, false)?;
+        let encoder_outputs = self.encoder.forward(&embedding_output)?;
+
+        self.layernorm.forward(&encoder_outputs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TrOCRForCausalLM {
+    decoder: TrOCRDecoder,
+    output_projection: Linear,
+}
+
+impl TrOCRForCausalLM {
+    pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
+        let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
+        let output_projection =
+            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
+        Ok(Self {
+            decoder,
+            output_projection,
+        })
+    }
+
+    pub fn forward(
+        &mut self,
+        xs: &Tensor,
+        encoder_xs: Option<&Tensor>,
+        past_kv_len: usize,
+        attn_mask: &Tensor,
+    ) -> Result<Tensor> {
+        let xs = self
+            .decoder
+            .forward(xs, encoder_xs, past_kv_len, attn_mask)?;
+        let xs = xs.apply(&self.output_projection)?;
+
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TrOCRModel {
+    encoder: TrOCREncoder,
+    decoder: TrOCRForCausalLM,
+}
+
+impl TrOCRModel {
+    pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
+        let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
+        let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
+        Ok(Self { encoder, decoder })
+    }
+
+    pub fn encoder(&mut self) -> &mut TrOCREncoder {
+        &mut self.encoder
+    }
+
+    pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
+        &mut self.decoder
+    }
+
+    pub fn decode(
+        &mut self,
+        xs: &Tensor,
+        encoder_xs: &Tensor,
+        past_kv_len: usize,
+    ) -> Result<Tensor> {
+        let seq_len = xs.dim(1)?;
+        let mask: Vec<_> = (0..seq_len)
+            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+            .collect();
+        let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
+
+        self.decoder
+            .forward(xs, Some(encoder_xs), past_kv_len, &mask)
+    }
+}
diff --git a/candle-transformers/src/models/vit.rs b/candle-transformers/src/models/vit.rs
index e2218c54..962528c1 100644
--- a/candle-transformers/src/models/vit.rs
+++ b/candle-transformers/src/models/vit.rs
@@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
 #[derive(Debug, Clone)]
 pub struct Config {
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
-    hidden_act: candle_nn::Activation,
-    layer_norm_eps: f64,
-    image_size: usize,
-    patch_size: usize,
-    num_channels: usize,
-    qkv_bias: bool,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
+    pub image_size: usize,
+    pub patch_size: usize,
+    pub num_channels: usize,
+    pub qkv_bias: bool,
 }
 
 impl Config {
@@ -34,6 +34,21 @@ impl Config {
             qkv_bias: true,
         }
     }
+
+    pub fn microsoft_trocr_base_handwritten() -> Self {
+        Self {
+            hidden_size: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            intermediate_size: 3072,
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-12,
+            image_size: 384,
+            patch_size: 16,
+            num_channels: 3,
+            qkv_bias: false,
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -76,7 +91,7 @@ impl Module for PatchEmbeddings {
 }
 
 #[derive(Debug, Clone)]
-struct Embeddings {
+pub struct Embeddings {
     cls_token: Tensor,
     mask_token: Option<Tensor>,
     patch_embeddings: PatchEmbeddings,
@@ -85,7 +100,7 @@ struct Embeddings {
 }
 
 impl Embeddings {
-    fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
         let hidden_size = cfg.hidden_size;
         let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
         let mask_token = if use_mask_token {
@@ -115,7 +130,7 @@
         todo!()
     }
 
-    fn forward(
+    pub fn forward(
         &self,
         pixel_values: &Tensor,
         bool_masked_pos: Option<&Tensor>,
@@ -324,12 +339,12 @@ impl Module for Layer {
 }
 
 #[derive(Debug, Clone)]
-struct Encoder {
+pub struct Encoder {
     layers: Vec<Layer>,
 }
 
 impl Encoder {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb = vb.pp("layer");
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         for i in 0..cfg.num_hidden_layers {