Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 02:16:37 +00:00
Add support for TrOCR Model (#1303)
* add bce with logit loss
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
* add trocr model
* fix formatting
* commit the actual model lol
* more formatting
* remove tokenizer config
BIN candle-examples/examples/trocr/assets/trocr.png (new file, 36 KiB; binary file not shown)
candle-examples/examples/trocr/image_processor.rs (new file, 154 lines)
@@ -0,0 +1,154 @@
use image::{DynamicImage, ImageBuffer};
use serde::Deserialize;
use std::collections::HashMap;

use candle::{DType, Device, Result, Tensor};

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProcessorConfig {
    do_resize: bool,
    height: u32,
    width: u32,
    do_rescale: bool,
    do_normalize: bool,
    image_mean: Vec<f32>,
    image_std: Vec<f32>,
}

impl Default for ProcessorConfig {
    fn default() -> Self {
        Self {
            do_resize: true,
            height: 384,
            width: 384,
            do_rescale: true,
            do_normalize: true,
            image_mean: vec![0.5, 0.5, 0.5],
            image_std: vec![0.5, 0.5, 0.5],
        }
    }
}

pub struct ViTImageProcessor {
    do_resize: bool,
    height: u32,
    width: u32,
    do_normalize: bool,
    image_mean: Vec<f32>,
    image_std: Vec<f32>,
}

impl ViTImageProcessor {
    pub fn new(config: &ProcessorConfig) -> Self {
        Self {
            do_resize: config.do_resize,
            height: config.height,
            width: config.width,
            do_normalize: config.do_normalize,
            image_mean: config.image_mean.clone(),
            image_std: config.image_std.clone(),
        }
    }

    pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
        let height = self.height as usize;
        let width = self.width as usize;
        let channels = 3;

        let images = self.load_images(images)?;

        let resized_images: Vec<DynamicImage> = if self.do_resize {
            images
                .iter()
                .map(|image| self.resize(image.clone(), None).unwrap())
                .collect()
        } else {
            images
        };

        let normalized_images: Vec<Tensor> = if self.do_normalize {
            resized_images
                .iter()
                .map(|image| self.normalize(image.clone(), None, None).unwrap())
                .collect()
        } else {
            let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
                resized_images.iter().map(|image| image.to_rgb8()).collect();
            let data = resized_images
                .into_iter()
                .map(|image| image.into_raw())
                .collect::<Vec<Vec<u8>>>();

            data.iter()
                .map(|image| {
                    Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
                        .unwrap()
                        .permute((2, 0, 1))
                        .unwrap()
                })
                .collect::<Vec<Tensor>>()
        };

        Tensor::stack(&normalized_images, 0)
    }

    fn resize(
        &self,
        image: image::DynamicImage,
        size: Option<HashMap<String, u32>>,
    ) -> Result<image::DynamicImage> {
        let (height, width) = match &size {
            Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
            None => (&self.height, &self.width),
        };

        let resized_image =
            image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);

        Ok(resized_image)
    }

    fn normalize(
        &self,
        image: image::DynamicImage,
        mean: Option<Vec<f32>>,
        std: Option<Vec<f32>>,
    ) -> Result<Tensor> {
        let mean = match mean {
            Some(mean) => mean,
            None => self.image_mean.clone(),
        };

        let std = match std {
            Some(std) => std,
            None => self.image_std.clone(),
        };

        let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
        let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;

        let image = image.to_rgb8();
        let data = image.into_raw();

        let height = self.height as usize;
        let width = self.width as usize;
        let channels = 3;

        let data =
            Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;

        (data.to_dtype(DType::F32)? / 255.)?
            .broadcast_sub(&mean)?
            .broadcast_div(&std)
    }

    pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
        let mut images: Vec<image::DynamicImage> = Vec::new();
        for path in image_path {
            let img = image::io::Reader::open(path)?.decode().unwrap();
            images.push(img);
        }

        Ok(images)
    }
}
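This processor mirrors the Hugging Face `ViTImageProcessor` preprocessing: resize to 384×384, scale to [0, 1], then normalize per channel with mean and std of 0.5, yielding values in [-1, 1]. A minimal standalone usage sketch, assuming the items above are in scope (from another module they would be reached as `image_processor::...`, as `main.rs` below does; the path is the asset added by this commit):

```rust
use candle::Result;

// Produces a (1, 3, 384, 384) f32 tensor ready for the ViT encoder.
fn preprocess_example() -> Result<candle::Tensor> {
    let config = ProcessorConfig::default();
    let processor = ViTImageProcessor::new(&config);
    processor.preprocess(vec!["candle-examples/examples/trocr/assets/trocr.png"])
}
```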
candle-examples/examples/trocr/main.rs (new file, 132 lines)
@@ -0,0 +1,132 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;

use tokenizers::Tokenizer;
mod image_processor;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    Base,
    Large,
}

#[derive(Parser, Debug)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    /// Choose the variant of the model to run.
    #[arg(long, default_value = "base")]
    which: Which,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The image file to be transcribed.
    #[arg(long)]
    image: String,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let tokenizer_dec = {
        let tokenizer = Api::new()?
            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
            .get("tokenizer.json")?;

        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;

    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-base-handwritten".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/3".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Large => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-large-handwritten".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/6".to_string(),
                    ))
                    .get("model.safetensors")?,
            },
        };
        println!("model: {:?}", model);
        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
    };

    // Note: both variants use the base handwritten encoder config here.
    let encoder_config = match args.which {
        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
        Which::Large => {
            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
        }
    };

    let decoder_config = trocr::TrOCRConfig::default();
    let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;

    let config = image_processor::ProcessorConfig::default();
    let processor = image_processor::ViTImageProcessor::new(&config);

    let image = vec![args.image.as_str()];
    let image = processor.preprocess(image)?;

    let encoder_xs = model.encoder().forward(&image)?;

    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
    for index in 0..1000 {
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;

        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;

        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);

        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == decoder_config.eos_token_id {
            break;
        }
    }

    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();

    Ok(())
}
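A note on the decoding loop above: only the first iteration feeds the full prompt (`context_size = token_ids.len()`); from then on `context_size` is 1, so each call to `model.decode` passes just the most recently sampled token while the attention layers in `trocr.rs` append the new keys and values to their `kv_cache`, avoiding recomputation over the whole prefix at every step.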
candle-examples/examples/trocr/readme.md (new file, 16 lines)
@@ -0,0 +1,16 @@
# candle-trocr

`TrOCR` is a transformer-based OCR model. In this example it is used to
transcribe the text contained in an image. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.

## Running an example

```bash
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
```

```
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```
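Not part of the committed readme, but based on the flags defined in `main.rs` above, the large checkpoint can be selected with `--which large` (an untested sketch; the weights come from microsoft/trocr-large-handwritten):

```bash
cargo run --example trocr --release -- --which large --cpu --image assets/trocr.png
```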
candle-transformers/src/models/mod.rs
@@ -29,6 +29,7 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod t5;
+pub mod trocr;
 pub mod vgg;
 pub mod vit;
 pub mod whisper;
candle-transformers/src/models/trocr.rs (new file, 434 lines)
@@ -0,0 +1,434 @@
use crate::models::vit::{Config, Embeddings, Encoder};
use candle::{Result, Tensor};
use candle_nn::{
    embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
};
use serde::Deserialize;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct TrOCRConfig {
    pub vocab_size: usize,
    pub d_model: usize,
    pub hidden_size: usize,
    pub decoder_layers: usize,
    pub decoder_attention_heads: usize,
    pub decoder_ffn_dim: usize,
    pub activation_function: candle_nn::Activation,
    pub max_position_embeddings: usize,
    pub dropout: f64,
    pub attention_dropout: f64,
    pub activation_dropout: f64,
    pub decoder_start_token_id: u32,
    pub init_std: f64,
    pub decoder_layerdrop: f64,
    pub use_cache: bool,
    pub scale_embedding: bool,
    pub use_learned_position_embeddings: bool,
    pub layernorm_embedding: bool,
    pub pad_token_id: usize,
    pub bos_token_id: usize,
    pub eos_token_id: u32,
    pub num_attention_heads: usize,
    pub decoder_vocab_size: Option<usize>,
}

impl Default for TrOCRConfig {
    fn default() -> Self {
        Self {
            vocab_size: 50265,
            d_model: 1024,
            hidden_size: 768,
            decoder_layers: 12,
            decoder_attention_heads: 16,
            decoder_ffn_dim: 4096,
            activation_function: candle_nn::Activation::Gelu,
            max_position_embeddings: 512,
            dropout: 0.1,
            attention_dropout: 0.0,
            activation_dropout: 0.0,
            decoder_start_token_id: 2,
            init_std: 0.02,
            decoder_layerdrop: 0.0,
            use_cache: true,
            scale_embedding: false,
            use_learned_position_embeddings: true,
            layernorm_embedding: true,
            pad_token_id: 1,
            bos_token_id: 0,
            eos_token_id: 2,
            num_attention_heads: 12,
            decoder_vocab_size: Some(50265),
        }
    }
}

#[derive(Debug, Clone)]
struct TrOCRLearnedPositionalEmbedding {
    offset: usize,
    weights: Embedding,
}

impl TrOCRLearnedPositionalEmbedding {
    fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
        let offset: usize = 2;
        let num_embeddings = cfg.max_position_embeddings;
        let embedding_dim = cfg.d_model;
        let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;

        Ok(Self { offset, weights })
    }

    fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
        let (b_sz, seq_len) = input_ids.dims2()?;

        let mut positions = Tensor::arange(
            past_key_values_length,
            seq_len as u32 + past_key_values_length,
            input_ids.device(),
        )?
        .expand((b_sz, seq_len))?;

        positions =
            positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
        self.weights.forward(&positions)
    }
}

#[derive(Debug, Clone)]
struct TrOCRAttention {
    head_dim: usize,
    num_heads: usize,
    is_decoder: bool,
    scaling: f64,
    k_proj: Linear,
    v_proj: Linear,
    q_proj: Linear,
    out_proj: Linear,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl TrOCRAttention {
    fn load(
        vb: VarBuilder,
        cfg: &TrOCRConfig,
        kdim: Option<usize>,
        vdim: Option<usize>,
    ) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_heads = cfg.decoder_attention_heads;
        let head_dim = embed_dim / num_heads;
        let kdim = kdim.unwrap_or(embed_dim);
        let vdim = vdim.unwrap_or(embed_dim);

        let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
        let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;

        let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
        Ok(Self {
            head_dim,
            num_heads,
            is_decoder: true,
            scaling: 1. / (head_dim as f64).sqrt(),
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            kv_cache: None,
        })
    }

    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
        tensor
            .reshape((bsz, (), self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        kv_states: Option<&Tensor>,
        attn_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (b_sz, tgt_len, _) = xs.dims3()?;
        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
        let (key_states, value_states) = match kv_states {
            None => {
                let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
                let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
                if self.is_decoder {
                    let kv_states = match &self.kv_cache {
                        None => (key_states, value_states),
                        Some((p_key_states, p_value_states)) => {
                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
                            (key_states, value_states)
                        }
                    };
                    self.kv_cache = Some(kv_states.clone());
                    kv_states
                } else {
                    (key_states, value_states)
                }
            }
            Some(kv_states) => {
                let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
                let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
                (key_states, value_states)
            }
        };
        let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
        let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
        let key_states = key_states.reshape(proj_shape)?;
        let value_states = value_states.reshape(proj_shape)?;
        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
        let attn_weights = match attn_mask {
            None => attn_weights,
            Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
        };
        let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_probs.matmul(&value_states)?;
        attn_output
            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
            .apply(&self.out_proj)
    }
}

#[derive(Debug, Clone)]
struct TrOCRDecoderLayer {
    self_attn: TrOCRAttention,
    activation_fn: candle_nn::Activation,
    self_attn_layer_norm: LayerNorm,
    encoder_attn: TrOCRAttention,
    encoder_attn_layer_norm: LayerNorm,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
}

impl TrOCRDecoderLayer {
    fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
        let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
        let encoder_attn = TrOCRAttention::load(
            vb.pp("encoder_attn"),
            cfg,
            Some(cfg.hidden_size),
            Some(cfg.hidden_size),
        )?;
        let encoder_attn_layer_norm =
            layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
        let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
        let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
        let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
        let activation_fn = candle_nn::Activation::Gelu;

        Ok(Self {
            self_attn,
            activation_fn,
            self_attn_layer_norm,
            encoder_attn,
            encoder_attn_layer_norm,
            fc1,
            fc2,
            final_layer_norm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = xs.clone();
        let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
        let xs = (xs + residual)?;
        let mut xs = self.self_attn_layer_norm.forward(&xs)?;

        if let Some(encoder_hidden_states) = &encoder_hidden_states {
            let residual = xs.clone();
            let encoder_attention_mask = attention_mask.clone(); // TODO
            xs = self.encoder_attn.forward(
                &xs,
                Some(encoder_hidden_states),
                Some(&encoder_attention_mask),
            )?;
            xs = (xs + residual)?;
            xs = self.encoder_attn_layer_norm.forward(&xs)?
        }

        let residual = xs.clone();
        let xs = self.fc1.forward(&xs)?;
        let xs = self.activation_fn.forward(&xs)?;
        let xs = self.fc2.forward(&xs)?;
        let xs = (xs + residual)?;
        let xs = self.final_layer_norm.forward(&xs)?;

        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct TrOCRDecoder {
    layers: Vec<TrOCRDecoderLayer>,
    embed_scale: Option<f64>,
    embed_tokens: Embedding,
    embed_positions: TrOCRLearnedPositionalEmbedding,
}

impl TrOCRDecoder {
    fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("decoder.model.decoder");

        let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
        let embed_positions =
            TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
        let mut layers = Vec::with_capacity(cfg.decoder_layers);
        let vb_l = vb.pp("layers");
        for idx in 0..cfg.decoder_layers {
            let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
            layers.push(layer)
        }
        let embed_scale = if cfg.scale_embedding {
            Some((cfg.d_model as f64).sqrt())
        } else {
            None
        };

        Ok(Self {
            layers,
            embed_scale,
            embed_tokens,
            embed_positions,
        })
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        encoder_xs: Option<&Tensor>,
        past_kv_len: usize,
        attn_mask: &Tensor,
    ) -> Result<Tensor> {
        let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
        let xs = xs.apply(&self.embed_tokens)?;

        let xs = match self.embed_scale {
            None => xs,
            Some(scale) => (xs * scale)?,
        };

        let mut xs = xs.broadcast_add(&embed_pos)?;

        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attn_mask, encoder_xs)?;
        }
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct TrOCREncoder {
    embeddings: Embeddings,
    encoder: Encoder,
    layernorm: LayerNorm,
}

impl TrOCREncoder {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_v = vb.pp("encoder");

        let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;

        let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;

        Ok(Self {
            embeddings,
            encoder,
            layernorm,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embedding_output = self.embeddings.forward(xs, None, false)?;
        let encoder_outputs = self.encoder.forward(&embedding_output)?;

        self.layernorm.forward(&encoder_outputs)
    }
}

#[derive(Debug, Clone)]
pub struct TrOCRForCausalLM {
    decoder: TrOCRDecoder,
    output_projection: Linear,
}

impl TrOCRForCausalLM {
    pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
        let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
        let output_projection =
            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
        Ok(Self {
            decoder,
            output_projection,
        })
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        encoder_xs: Option<&Tensor>,
        past_kv_len: usize,
        attn_mask: &Tensor,
    ) -> Result<Tensor> {
        let xs = self
            .decoder
            .forward(xs, encoder_xs, past_kv_len, attn_mask)?;
        let xs = xs.apply(&self.output_projection)?;

        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct TrOCRModel {
    encoder: TrOCREncoder,
    decoder: TrOCRForCausalLM,
}

impl TrOCRModel {
    pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
        let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
        let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
        Ok(Self { encoder, decoder })
    }

    pub fn encoder(&mut self) -> &mut TrOCREncoder {
        &mut self.encoder
    }

    pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
        &mut self.decoder
    }

    pub fn decode(
        &mut self,
        xs: &Tensor,
        encoder_xs: &Tensor,
        past_kv_len: usize,
    ) -> Result<Tensor> {
        let seq_len = xs.dim(1)?;
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
            .collect();
        let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;

        self.decoder
            .forward(xs, Some(encoder_xs), past_kv_len, &mask)
    }
}
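Two details in this file are worth calling out. `TrOCRForCausalLM` ties the output projection to the token embedding matrix (`Linear::new(decoder.embed_tokens.embeddings().clone(), None)`), so no separate lm-head weights are loaded. And `TrOCRModel::decode` builds a standard causal mask, zero on and below the diagonal and negative infinity above it; as a worked illustration (not code from the commit), for `seq_len = 3` it looks like:

```rust
// Causal mask produced by `decode` for seq_len = 3: row i (the query
// position) may attend to columns j <= i; the -inf entries are added to
// the attention logits and zeroed out by the subsequent softmax.
// [[0.0, -inf, -inf],
//  [0.0,  0.0, -inf],
//  [0.0,  0.0,  0.0]]
```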
candle-transformers/src/models/vit.rs
@@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
 #[derive(Debug, Clone)]
 pub struct Config {
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
-    hidden_act: candle_nn::Activation,
-    layer_norm_eps: f64,
-    image_size: usize,
-    patch_size: usize,
-    num_channels: usize,
-    qkv_bias: bool,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
+    pub image_size: usize,
+    pub patch_size: usize,
+    pub num_channels: usize,
+    pub qkv_bias: bool,
 }

 impl Config {
@@ -34,6 +34,21 @@ impl Config {
             qkv_bias: true,
         }
     }
+
+    pub fn microsoft_trocr_base_handwritten() -> Self {
+        Self {
+            hidden_size: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            intermediate_size: 3072,
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-12,
+            image_size: 384,
+            patch_size: 16,
+            num_channels: 3,
+            qkv_bias: false,
+        }
+    }
 }

 #[derive(Debug, Clone)]
@@ -76,7 +91,7 @@ impl Module for PatchEmbeddings {
 }

 #[derive(Debug, Clone)]
-struct Embeddings {
+pub struct Embeddings {
     cls_token: Tensor,
     mask_token: Option<Tensor>,
     patch_embeddings: PatchEmbeddings,
@@ -85,7 +100,7 @@ struct Embeddings {
 }

 impl Embeddings {
-    fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
         let hidden_size = cfg.hidden_size;
         let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
         let mask_token = if use_mask_token {
@@ -115,7 +130,7 @@ impl Embeddings {
         todo!()
     }

-    fn forward(
+    pub fn forward(
         &self,
         pixel_values: &Tensor,
         bool_masked_pos: Option<&Tensor>,
@@ -324,12 +339,12 @@ impl Module for Layer {
 }

 #[derive(Debug, Clone)]
-struct Encoder {
+pub struct Encoder {
     layers: Vec<Layer>,
 }

 impl Encoder {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb = vb.pp("layer");
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         for i in 0..cfg.num_hidden_layers {