Add support for TrOCR Model (#1303)
* add bce with logit loss (see the usage sketch below)
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
* add trocr model
* fix formatting
* commit the actual model lol
* more formatting
* remove tokenizer config
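The loss change itself is outside the hunks shown below. As a rough sketch only, assuming the helper landed as `candle_nn::loss::binary_cross_entropy_with_logit` with the same `(input, target) -> Result<Tensor>` shape as the other helpers in `candle_nn::loss`:

```rust
// Sketch only: the function path and signature are assumptions, since the
// loss addition is not part of the diff shown on this page.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw (pre-sigmoid) logits and binary targets of the same shape.
    let logits = Tensor::new(&[[0.5f32, -1.0], [2.0, 0.3]], &device)?;
    let targets = Tensor::new(&[[1.0f32, 0.0], [1.0, 0.0]], &device)?;
    // Fusing the sigmoid into the loss is numerically stabler than
    // applying sigmoid first and then plain binary cross-entropy.
    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&logits, &targets)?;
    println!("loss: {loss}");
    Ok(())
}
```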
```diff
@@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
 #[derive(Debug, Clone)]
 pub struct Config {
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
-    hidden_act: candle_nn::Activation,
-    layer_norm_eps: f64,
-    image_size: usize,
-    patch_size: usize,
-    num_channels: usize,
-    qkv_bias: bool,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
+    pub image_size: usize,
+    pub patch_size: usize,
+    pub num_channels: usize,
+    pub qkv_bias: bool,
 }
 
 impl Config {
```
```diff
@@ -34,6 +34,21 @@ impl Config {
             qkv_bias: true,
         }
     }
+
+    pub fn microsoft_trocr_base_handwritten() -> Self {
+        Self {
+            hidden_size: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            intermediate_size: 3072,
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-12,
+            image_size: 384,
+            patch_size: 16,
+            num_channels: 3,
+            qkv_bias: false,
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
```
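A quick usage note for the preset above. The module path is assumed to be candle-transformers' `models::vit` (the page does not show the file name); the field values come straight from the hunk:

```rust
// Hypothetical import path; the hunk only shows the module's contents.
use candle_transformers::models::vit::Config;

fn main() {
    let cfg = Config::microsoft_trocr_base_handwritten();
    // With the fields now `pub`, downstream code such as the TrOCR model
    // can read them directly:
    assert_eq!(cfg.hidden_size, 768);
    assert_eq!(cfg.image_size / cfg.patch_size, 24); // 24x24 = 576 patches
}
```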
```diff
@@ -76,7 +91,7 @@ impl Module for PatchEmbeddings {
 }
 
 #[derive(Debug, Clone)]
-struct Embeddings {
+pub struct Embeddings {
    cls_token: Tensor,
    mask_token: Option<Tensor>,
    patch_embeddings: PatchEmbeddings,
```
```diff
@@ -85,7 +100,7 @@ struct Embeddings {
 }
 
 impl Embeddings {
-    fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
        let mask_token = if use_mask_token {
```
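The constructor signature appears in full in this hunk, so its new visibility can be sketched directly. A minimal smoke-test sketch, assuming `Embeddings` is reachable under the same assumed module path as `Config` above:

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::vit::{Config, Embeddings};

fn main() -> Result<()> {
    // Zero-initialized weights are enough for a shape/smoke test; real use
    // would load a checkpoint into the VarBuilder instead.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let cfg = Config::microsoft_trocr_base_handwritten();
    // `use_mask_token: false`: TrOCR-style inference does not mask patches.
    // The "embeddings" prefix is illustrative; checkpoint layouts may differ.
    let _embeddings = Embeddings::new(&cfg, false, vb.pp("embeddings"))?;
    Ok(())
}
```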
```diff
@@ -115,7 +130,7 @@ impl Embeddings {
         todo!()
     }
 
-    fn forward(
+    pub fn forward(
        &self,
        pixel_values: &Tensor,
        bool_masked_pos: Option<&Tensor>,
```
```diff
@@ -324,12 +339,12 @@ impl Module for Layer {
 }
 
 #[derive(Debug, Clone)]
-struct Encoder {
+pub struct Encoder {
    layers: Vec<Layer>,
 }
 
 impl Encoder {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("layer");
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
```
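Taken together, the visibility changes let the ViT encoder be assembled from outside the module. A hedged sketch under the same assumed paths as the earlier examples:

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::vit::{Config, Encoder};

fn main() -> Result<()> {
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let cfg = Config::microsoft_trocr_base_handwritten();
    // Encoder::new pushes its own "layer" prefix (see `vb.pp("layer")` in
    // the hunk above), so the caller only supplies the prefix above it;
    // "encoder" here is illustrative.
    let _encoder = Encoder::new(&cfg, vb.pp("encoder"))?;
    Ok(())
}
```

Presumably the point of making these items `pub` rather than duplicating the ViT code is that the new TrOCR model can reuse this encoder as its image branch.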