Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Boilerplate code for Bert.
candle-examples/examples/bert/main.rs (new file, 239 lines)
@@ -0,0 +1,239 @@
#![allow(dead_code)]

use anyhow::Result as R;
use candle::{Result, Tensor};

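// Hidden activation function, selected via `Config::hidden_act`.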
#[derive(Debug, Clone, PartialEq, Eq)]
enum HiddenAct {
    Gelu,
    Relu,
}

impl HiddenAct {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu(),
            Self::Relu => xs.relu(),
        }
    }
}

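// Only absolute position embeddings are supported for now.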
#[derive(Debug, Clone, PartialEq, Eq)]
enum PositionEmbeddingType {
    Absolute,
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
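// Model hyperparameters; the defaults below match the bert-base-uncased checkpoint.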
#[derive(Debug, Clone, PartialEq)]
struct Config {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    hidden_act: HiddenAct,
    hidden_dropout_prob: f64,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    initializer_range: f64,
    layer_norm_eps: f64,
    pad_token_id: usize,
    position_embedding_type: PositionEmbeddingType,
    use_cache: bool,
    classifier_dropout: Option<f64>,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 30522,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: HiddenAct::Gelu,
            hidden_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Absolute,
            use_cache: true,
            classifier_dropout: None,
        }
    }
}

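// Minimal embedding layer: a lookup table stored as a single tensor.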
struct Embedding {
    embeddings: Tensor,
}

impl Embedding {
    fn new(embeddings: Tensor) -> Self {
        Self { embeddings }
    }

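    // Gather one embedding row per index in `indexes`.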
    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
        Tensor::embedding(indexes, &self.embeddings)
    }
}

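// Bias-free linear layer: y = x * W^T.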
struct Linear {
    weight: Tensor,
}

impl Linear {
    fn new(weight: Tensor) -> Self {
        Self { weight }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.matmul(&self.weight.t()?)?;
        Ok(x)
    }
}

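// Placeholder dropout: a no-op for now, which is fine for inference-only use.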
struct Dropout {}

impl Dropout {
    fn new() -> Self {
        Self {}
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // TODO
        Ok(x.clone())
    }
}

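// Layer norm with a learnable scale only (no bias).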
struct LayerNorm {
    scale: Tensor,
}

impl LayerNorm {
    fn new(scale: Tensor) -> Self {
        Self { scale }
    }

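    // Note: this normalizes by the root-mean-square of the activations without
    // subtracting the mean first, so it behaves like RMSNorm rather than the full
    // LayerNorm of the reference BERT implementation, and it uses a hard-coded
    // 1e-5 epsilon instead of `config.layer_norm_eps`.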
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (seq_len, hidden_size) = x.shape().r2()?;
        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
        let size = self.scale.shape().r1()?;
        let scale = self.scale.broadcast_as((seq_len, size))?;
        let x = (scale * x_normed)?;
        Ok(x)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
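// Word, position and token-type embeddings that get summed into the input representation.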
struct BertEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    token_type_embeddings: Embedding,
    position_ids: Tensor,
    token_type_ids: Tensor,
}

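// Query/key/value projections for self-attention; the attention computation itself is not implemented yet.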
struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    dropout: Dropout,
}

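// Projects the attention output, then applies dropout, a residual connection and layer norm.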
struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    dropout: Dropout,
}

impl BertSelfOutput {
    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.dropout.forward(&hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
}

impl BertAttention {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
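// Feed-forward expansion to intermediate_size, followed by the hidden activation.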
struct BertIntermediate {
    dense: Linear,
    intermediate_act: HiddenAct,
}

impl BertIntermediate {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        self.intermediate_act.forward(&hidden_states)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
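// Projects the intermediate activations back to hidden_size, with dropout, residual and layer norm.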
struct BertOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    dropout: Dropout,
}

impl BertOutput {
    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.dropout.forward(&hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
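// A single transformer block: self-attention followed by the feed-forward sub-layer.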
struct BertLayer {
    attention: BertAttention,
    intermediate: BertIntermediate,
    output: BertOutput,
}

impl BertLayer {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
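// Stack of num_hidden_layers transformer blocks.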
struct BertEncoder {
    layers: Vec<BertLayer>,
}

impl BertEncoder {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
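// The full model: embeddings followed by the encoder stack (no pooler yet).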
struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
}

impl BertModel {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

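// Nothing to run yet: this example only defines the model skeleton.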
fn main() -> R<()> {
    Ok(())
}