Add support for TrOCR Model (#1303)

* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting

* add trocr model

* fix formatting

* commit the actual model lol

* more formatting

* remove tokenizer config
Author: Ogundepo Odunayo
Date: 2023-11-09 12:49:17 -05:00
Committed by: GitHub
Parent: e6697471bb
Commit: 6958384327
7 changed files with 767 additions and 15 deletions
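
The first two commits in the list above add a binary cross-entropy-with-logits loss, i.e. the sigmoid and the binary cross-entropy fused into one numerically stable op: loss = -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]. A minimal sketch of calling such a helper; the `candle_nn::loss::binary_cross_entropy_with_logit` name is assumed from the commit messages and does not appear in the diff excerpt below:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw pre-sigmoid scores and matching binary targets.
    let inp = Tensor::new(&[[0.5f32, -1.0, 2.0]], &device)?;
    let target = Tensor::new(&[[1.0f32, 0.0, 1.0]], &device)?;
    // Assumed helper added by this PR: sigmoid + BCE in one call.
    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
    println!("{loss}");
    Ok(())
}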


@@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
 #[derive(Debug, Clone)]
 pub struct Config {
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
-    hidden_act: candle_nn::Activation,
-    layer_norm_eps: f64,
-    image_size: usize,
-    patch_size: usize,
-    num_channels: usize,
-    qkv_bias: bool,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
+    pub image_size: usize,
+    pub patch_size: usize,
+    pub num_channels: usize,
+    pub qkv_bias: bool,
 }
 
 impl Config {
@@ -34,6 +34,21 @@ impl Config {
             qkv_bias: true,
         }
     }
+
+    pub fn microsoft_trocr_base_handwritten() -> Self {
+        Self {
+            hidden_size: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            intermediate_size: 3072,
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-12,
+            image_size: 384,
+            patch_size: 16,
+            num_channels: 3,
+            qkv_bias: false,
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
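
With the Config fields now public and a TrOCR preset available, callers outside the vit module can start from the preset and adjust it. A minimal sketch, assuming the module sits at `candle_transformers::models::vit` (the surrounding crate layout is not shown in this diff):

use candle_transformers::models::vit::Config;

fn main() {
    // Start from the preset added above...
    let mut cfg = Config::microsoft_trocr_base_handwritten();
    // ...then override individual fields, which is what `pub` enables here.
    // The 224-pixel input size is purely illustrative.
    cfg.image_size = 224;
    assert_eq!(cfg.patch_size, 16);
    assert_eq!(cfg.num_attention_heads, 12);
}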
@@ -76,7 +91,7 @@ impl Module for PatchEmbeddings {
 }
 
 #[derive(Debug, Clone)]
-struct Embeddings {
+pub struct Embeddings {
     cls_token: Tensor,
     mask_token: Option<Tensor>,
     patch_embeddings: PatchEmbeddings,
@@ -85,7 +100,7 @@ struct Embeddings {
 }
 
 impl Embeddings {
-    fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
         let hidden_size = cfg.hidden_size;
         let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
         let mask_token = if use_mask_token {
@@ -115,7 +130,7 @@ impl Embeddings {
         todo!()
     }
 
-    fn forward(
+    pub fn forward(
         &self,
         pixel_values: &Tensor,
         bool_masked_pos: Option<&Tensor>,
@@ -324,12 +339,12 @@ impl Module for Layer {
 }
 
 #[derive(Debug, Clone)]
-struct Encoder {
+pub struct Encoder {
     layers: Vec<Layer>,
 }
 
 impl Encoder {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb = vb.pp("layer");
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         for i in 0..cfg.num_hidden_layers {
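
Exposing Embeddings, Encoder, and their constructors is what lets the new trocr module assemble a ViT image encoder from these building blocks. A rough sketch of that composition; the weight prefixes and the safetensors filename are illustrative assumptions, not taken from this diff:

use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::vit::{Config, Embeddings, Encoder};

fn build_image_encoder(device: &Device) -> Result<(Embeddings, Encoder)> {
    // Hypothetical local checkpoint; in practice it would come from the HF hub.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, device)?
    };
    let cfg = Config::microsoft_trocr_base_handwritten();
    // The constructors made `pub` above are what permit this reuse.
    let embeddings = Embeddings::new(&cfg, false, vb.pp("encoder.embeddings"))?;
    let encoder = Encoder::new(&cfg, vb.pp("encoder.encoder"))?;
    Ok((embeddings, encoder))
}

The TrOCR model added by this PR pairs such an encoder with a text decoder; the exact weight-path prefixes it uses may differ from the ones shown here.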