Files
candle/candle-examples/examples/bert/main.rs
2023-07-03 12:17:06 +01:00

240 lines
6.0 KiB
Rust

#![allow(dead_code)]
use anyhow::Result as R;
use candle::{Result, Tensor};
#[derive(Debug, Clone, PartialEq, Eq)]
enum HiddenAct {
Gelu,
Relu,
}
impl HiddenAct {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
Self::Relu => xs.relu(),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum PositionEmbeddingType {
Absolute,
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq)]
struct Config {
vocab_size: usize,
hidden_size: usize,
num_hidden_layers: usize,
num_attention_heads: usize,
intermediate_size: usize,
hidden_act: HiddenAct,
hidden_dropout_prob: f64,
max_position_embeddings: usize,
type_vocab_size: usize,
initializer_range: f64,
layer_norm_eps: f64,
pad_token_id: usize,
position_embedding_type: PositionEmbeddingType,
use_cache: bool,
classifier_dropout: Option<f64>,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 30522,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: HiddenAct::Gelu,
hidden_dropout_prob: 0.1,
max_position_embeddings: 512,
type_vocab_size: 2,
initializer_range: 0.02,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
classifier_dropout: None,
}
}
}
struct Embedding {
embeddings: Tensor,
}
impl Embedding {
fn new(embeddings: Tensor) -> Self {
Self { embeddings }
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
Tensor::embedding(indexes, &self.embeddings)
}
}
struct Linear {
weight: Tensor,
}
impl Linear {
fn new(weight: Tensor) -> Self {
Self { weight }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.weight.t()?)?;
Ok(x)
}
}
struct Dropout {}
impl Dropout {
fn new() -> Self {
Self {}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO
Ok(x.clone())
}
}
struct LayerNorm {
scale: Tensor,
}
impl LayerNorm {
fn new(scale: Tensor) -> Self {
Self { scale }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (seq_len, hidden_size) = x.shape().r2()?;
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().r1()?;
let scale = self.scale.broadcast_as((seq_len, size))?;
let x = (scale * x_normed)?;
Ok(x)
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
struct BertEmbeddings {
word_embeddings: Embedding,
position_embeddings: Embedding,
token_type_embeddings: Embedding,
position_ids: Tensor,
token_type_ids: Tensor,
}
struct BertSelfAttention {
query: Linear,
key: Linear,
value: Linear,
dropout: Dropout,
}
struct BertSelfOutput {
dense: Linear,
layer_norm: LayerNorm,
dropout: Dropout,
}
impl BertSelfOutput {
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states)?;
self.layer_norm.forward(&(hidden_states + input_tensor)?)
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
struct BertAttention {
self_attention: BertSelfAttention,
self_output: BertSelfOutput,
}
impl BertAttention {
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
struct BertIntermediate {
dense: Linear,
intermediate_act: HiddenAct,
}
impl BertIntermediate {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
self.intermediate_act.forward(&hidden_states)
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
struct BertOutput {
dense: Linear,
layer_norm: LayerNorm,
dropout: Dropout,
}
impl BertOutput {
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states)?;
self.layer_norm.forward(&(hidden_states + input_tensor)?)
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
struct BertLayer {
attention: BertAttention,
intermediate: BertIntermediate,
output: BertOutput,
}
impl BertLayer {
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct BertEncoder {
layers: Vec<BertLayer>,
}
impl BertEncoder {
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
struct BertModel {
embeddings: BertEmbeddings,
encoder: BertEncoder,
}
impl BertModel {
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
fn main() -> R<()> {
Ok(())
}