mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Bert tracing (#184)
* Add some tracing to bert. * More tracing. * Add a flag for tracing.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -21,6 +21,7 @@ perf.data
|
|||||||
flamegraph.svg
|
flamegraph.svg
|
||||||
*.so
|
*.so
|
||||||
*.swp
|
*.swp
|
||||||
|
trace-*.json
|
||||||
|
|
||||||
candle-wasm-example/*.wav
|
candle-wasm-example/*.wav
|
||||||
candle-wasm-example/*.safetensors
|
candle-wasm-example/*.safetensors
|
||||||
|
@ -42,6 +42,9 @@ thiserror = "1"
|
|||||||
tokenizers = { version = "0.13.3", default-features = false, features = ["onig"] }
|
tokenizers = { version = "0.13.3", default-features = false, features = ["onig"] }
|
||||||
tokio = "1.28.2"
|
tokio = "1.28.2"
|
||||||
tokio-test = "0.4.2"
|
tokio-test = "0.4.2"
|
||||||
|
tracing = "0.1.37"
|
||||||
|
tracing-chrome = "0.7.1"
|
||||||
|
tracing-subscriber = "0.3.7"
|
||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
|
|
||||||
|
@ -25,6 +25,9 @@ candle-hub = { path = "../candle-hub" }
|
|||||||
clap = { workspace = true }
|
clap = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
tokenizers = { workspace = true }
|
tokenizers = { workspace = true }
|
||||||
|
tracing = { workspace = true }
|
||||||
|
tracing-chrome = { workspace = true }
|
||||||
|
tracing-subscriber = { workspace = true }
|
||||||
wav = { workspace = true }
|
wav = { workspace = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
@ -1,471 +1,15 @@
|
|||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
mod model;
|
||||||
|
|
||||||
use anyhow::{anyhow, Error as E, Result};
|
use anyhow::{anyhow, Error as E, Result};
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::Tensor;
|
||||||
use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
|
use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||||
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
|
use candle_nn::VarBuilder;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use serde::Deserialize;
|
use model::{BertModel, Config, DTYPE};
|
||||||
use tokenizers::{PaddingParams, Tokenizer};
|
use tokenizers::{PaddingParams, Tokenizer};
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
enum HiddenAct {
|
|
||||||
Gelu,
|
|
||||||
Relu,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HiddenAct {
|
|
||||||
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
|
||||||
match self {
|
|
||||||
// TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
|
|
||||||
// small numerical difference.
|
|
||||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
|
||||||
Self::Gelu => xs.gelu(),
|
|
||||||
Self::Relu => xs.relu(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
enum PositionEmbeddingType {
|
|
||||||
#[default]
|
|
||||||
Absolute,
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
|
||||||
struct Config {
|
|
||||||
vocab_size: usize,
|
|
||||||
hidden_size: usize,
|
|
||||||
num_hidden_layers: usize,
|
|
||||||
num_attention_heads: usize,
|
|
||||||
intermediate_size: usize,
|
|
||||||
hidden_act: HiddenAct,
|
|
||||||
hidden_dropout_prob: f64,
|
|
||||||
max_position_embeddings: usize,
|
|
||||||
type_vocab_size: usize,
|
|
||||||
initializer_range: f64,
|
|
||||||
layer_norm_eps: f64,
|
|
||||||
pad_token_id: usize,
|
|
||||||
#[serde(default)]
|
|
||||||
position_embedding_type: PositionEmbeddingType,
|
|
||||||
#[serde(default)]
|
|
||||||
use_cache: bool,
|
|
||||||
classifier_dropout: Option<f64>,
|
|
||||||
model_type: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for Config {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
vocab_size: 30522,
|
|
||||||
hidden_size: 768,
|
|
||||||
num_hidden_layers: 12,
|
|
||||||
num_attention_heads: 12,
|
|
||||||
intermediate_size: 3072,
|
|
||||||
hidden_act: HiddenAct::Gelu,
|
|
||||||
hidden_dropout_prob: 0.1,
|
|
||||||
max_position_embeddings: 512,
|
|
||||||
type_vocab_size: 2,
|
|
||||||
initializer_range: 0.02,
|
|
||||||
layer_norm_eps: 1e-12,
|
|
||||||
pad_token_id: 0,
|
|
||||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
|
||||||
use_cache: true,
|
|
||||||
classifier_dropout: None,
|
|
||||||
model_type: Some("bert".to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
|
|
||||||
fn _all_mini_lm_l6_v2() -> Self {
|
|
||||||
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
|
|
||||||
Self {
|
|
||||||
vocab_size: 30522,
|
|
||||||
hidden_size: 384,
|
|
||||||
num_hidden_layers: 6,
|
|
||||||
num_attention_heads: 12,
|
|
||||||
intermediate_size: 1536,
|
|
||||||
hidden_act: HiddenAct::Gelu,
|
|
||||||
hidden_dropout_prob: 0.1,
|
|
||||||
max_position_embeddings: 512,
|
|
||||||
type_vocab_size: 2,
|
|
||||||
initializer_range: 0.02,
|
|
||||||
layer_norm_eps: 1e-12,
|
|
||||||
pad_token_id: 0,
|
|
||||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
|
||||||
use_cache: true,
|
|
||||||
classifier_dropout: None,
|
|
||||||
model_type: Some("bert".to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
|
||||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
|
||||||
Ok(Embedding::new(embeddings, hidden_size))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
|
||||||
let weight = vb.get((size2, size1), "weight")?;
|
|
||||||
let bias = vb.get(size2, "bias")?;
|
|
||||||
Ok(Linear::new(weight, Some(bias)))
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Dropout {
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pr: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Dropout {
|
|
||||||
fn new(pr: f64) -> Self {
|
|
||||||
Self { pr }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
|
||||||
// TODO
|
|
||||||
Ok(x.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
|
||||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
|
||||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
|
||||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
|
||||||
(weight, bias)
|
|
||||||
} else {
|
|
||||||
return Err(err.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(LayerNorm::new(weight, bias, eps))
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
|
|
||||||
struct BertEmbeddings {
|
|
||||||
word_embeddings: Embedding,
|
|
||||||
position_embeddings: Option<Embedding>,
|
|
||||||
token_type_embeddings: Embedding,
|
|
||||||
layer_norm: LayerNorm,
|
|
||||||
dropout: Dropout,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertEmbeddings {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let word_embeddings = embedding(
|
|
||||||
config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
vb.pp("word_embeddings"),
|
|
||||||
)?;
|
|
||||||
let position_embeddings = embedding(
|
|
||||||
config.max_position_embeddings,
|
|
||||||
config.hidden_size,
|
|
||||||
vb.pp("position_embeddings"),
|
|
||||||
)?;
|
|
||||||
let token_type_embeddings = embedding(
|
|
||||||
config.type_vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
vb.pp("token_type_embeddings"),
|
|
||||||
)?;
|
|
||||||
let layer_norm = layer_norm(
|
|
||||||
config.hidden_size,
|
|
||||||
config.layer_norm_eps,
|
|
||||||
vb.pp("LayerNorm"),
|
|
||||||
)?;
|
|
||||||
Ok(Self {
|
|
||||||
word_embeddings,
|
|
||||||
position_embeddings: Some(position_embeddings),
|
|
||||||
token_type_embeddings,
|
|
||||||
layer_norm,
|
|
||||||
dropout: Dropout::new(config.hidden_dropout_prob),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
|
||||||
let (_bsize, seq_len) = input_ids.shape().r2()?;
|
|
||||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
|
||||||
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
|
||||||
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
|
|
||||||
if let Some(position_embeddings) = &self.position_embeddings {
|
|
||||||
// TODO: Proper absolute positions?
|
|
||||||
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
|
|
||||||
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
|
|
||||||
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
|
|
||||||
}
|
|
||||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
|
||||||
let embeddings = self.dropout.forward(&embeddings)?;
|
|
||||||
Ok(embeddings)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct BertSelfAttention {
|
|
||||||
query: Linear,
|
|
||||||
key: Linear,
|
|
||||||
value: Linear,
|
|
||||||
dropout: Dropout,
|
|
||||||
num_attention_heads: usize,
|
|
||||||
attention_head_size: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertSelfAttention {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
|
||||||
let all_head_size = config.num_attention_heads * attention_head_size;
|
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
|
||||||
let hidden_size = config.hidden_size;
|
|
||||||
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
|
|
||||||
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
|
|
||||||
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
|
|
||||||
Ok(Self {
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
dropout,
|
|
||||||
num_attention_heads: config.num_attention_heads,
|
|
||||||
attention_head_size,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let mut new_x_shape = xs.dims().to_vec();
|
|
||||||
new_x_shape.pop();
|
|
||||||
new_x_shape.push(self.num_attention_heads);
|
|
||||||
new_x_shape.push(self.attention_head_size);
|
|
||||||
// Be cautious about the transposition if adding a batch dim!
|
|
||||||
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
|
|
||||||
Ok(xs.contiguous()?)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
|
||||||
let query_layer = self.query.forward(hidden_states)?;
|
|
||||||
let key_layer = self.key.forward(hidden_states)?;
|
|
||||||
let value_layer = self.value.forward(hidden_states)?;
|
|
||||||
|
|
||||||
let query_layer = self.transpose_for_scores(&query_layer)?;
|
|
||||||
let key_layer = self.transpose_for_scores(&key_layer)?;
|
|
||||||
let value_layer = self.transpose_for_scores(&value_layer)?;
|
|
||||||
|
|
||||||
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
|
|
||||||
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
|
|
||||||
let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
|
|
||||||
let attention_probs = self.dropout.forward(&attention_probs)?;
|
|
||||||
|
|
||||||
let context_layer = attention_probs.matmul(&value_layer)?;
|
|
||||||
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
|
|
||||||
let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
|
|
||||||
Ok(context_layer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct BertSelfOutput {
|
|
||||||
dense: Linear,
|
|
||||||
layer_norm: LayerNorm,
|
|
||||||
dropout: Dropout,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertSelfOutput {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
|
|
||||||
let layer_norm = layer_norm(
|
|
||||||
config.hidden_size,
|
|
||||||
config.layer_norm_eps,
|
|
||||||
vb.pp("LayerNorm"),
|
|
||||||
)?;
|
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
|
||||||
Ok(Self {
|
|
||||||
dense,
|
|
||||||
layer_norm,
|
|
||||||
dropout,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
|
||||||
let hidden_states = self.dense.forward(hidden_states)?;
|
|
||||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
|
||||||
Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
|
|
||||||
struct BertAttention {
|
|
||||||
self_attention: BertSelfAttention,
|
|
||||||
self_output: BertSelfOutput,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertAttention {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
|
|
||||||
let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
|
|
||||||
Ok(Self {
|
|
||||||
self_attention,
|
|
||||||
self_output,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
|
||||||
let self_outputs = self.self_attention.forward(hidden_states)?;
|
|
||||||
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
|
|
||||||
Ok(attention_output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
|
|
||||||
struct BertIntermediate {
|
|
||||||
dense: Linear,
|
|
||||||
intermediate_act: HiddenAct,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertIntermediate {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
|
|
||||||
Ok(Self {
|
|
||||||
dense,
|
|
||||||
intermediate_act: config.hidden_act,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
|
||||||
let hidden_states = self.dense.forward(hidden_states)?;
|
|
||||||
let ys = self.intermediate_act.forward(&hidden_states)?;
|
|
||||||
Ok(ys)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
|
|
||||||
struct BertOutput {
|
|
||||||
dense: Linear,
|
|
||||||
layer_norm: LayerNorm,
|
|
||||||
dropout: Dropout,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertOutput {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
|
|
||||||
let layer_norm = layer_norm(
|
|
||||||
config.hidden_size,
|
|
||||||
config.layer_norm_eps,
|
|
||||||
vb.pp("LayerNorm"),
|
|
||||||
)?;
|
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
|
||||||
Ok(Self {
|
|
||||||
dense,
|
|
||||||
layer_norm,
|
|
||||||
dropout,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
|
||||||
let hidden_states = self.dense.forward(hidden_states)?;
|
|
||||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
|
||||||
Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
|
|
||||||
struct BertLayer {
|
|
||||||
attention: BertAttention,
|
|
||||||
intermediate: BertIntermediate,
|
|
||||||
output: BertOutput,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertLayer {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let attention = BertAttention::load(vb.pp("attention"), config)?;
|
|
||||||
let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
|
|
||||||
let output = BertOutput::load(vb.pp("output"), config)?;
|
|
||||||
Ok(Self {
|
|
||||||
attention,
|
|
||||||
intermediate,
|
|
||||||
output,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
|
||||||
let attention_output = self.attention.forward(hidden_states)?;
|
|
||||||
// TODO: Support cross-attention?
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
|
|
||||||
// TODO: Support something similar to `apply_chunking_to_forward`?
|
|
||||||
let intermediate_output = self.intermediate.forward(&attention_output)?;
|
|
||||||
let layer_output = self
|
|
||||||
.output
|
|
||||||
.forward(&intermediate_output, &attention_output)?;
|
|
||||||
Ok(layer_output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
|
|
||||||
struct BertEncoder {
|
|
||||||
layers: Vec<BertLayer>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertEncoder {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let layers = (0..config.num_hidden_layers)
|
|
||||||
.map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
Ok(BertEncoder { layers })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
|
||||||
let mut hidden_states = hidden_states.clone();
|
|
||||||
// Use a loop rather than a fold as it's easier to modify when adding debug/...
|
|
||||||
for layer in self.layers.iter() {
|
|
||||||
hidden_states = layer.forward(&hidden_states)?
|
|
||||||
}
|
|
||||||
Ok(hidden_states)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
|
|
||||||
struct BertModel {
|
|
||||||
embeddings: BertEmbeddings,
|
|
||||||
encoder: BertEncoder,
|
|
||||||
device: Device,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BertModel {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let (embeddings, encoder) = match (
|
|
||||||
BertEmbeddings::load(vb.pp("embeddings"), config),
|
|
||||||
BertEncoder::load(vb.pp("encoder"), config),
|
|
||||||
) {
|
|
||||||
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
|
||||||
if let Some(model_type) = &config.model_type {
|
|
||||||
if let (Ok(embeddings), Ok(encoder)) = (
|
|
||||||
BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
|
||||||
BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
|
|
||||||
) {
|
|
||||||
(embeddings, encoder)
|
|
||||||
} else {
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(Self {
|
|
||||||
embeddings,
|
|
||||||
encoder,
|
|
||||||
device: vb.device().clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
|
||||||
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
|
|
||||||
let sequence_output = self.encoder.forward(&embedding_output)?;
|
|
||||||
Ok(sequence_output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -477,6 +21,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
offline: bool,
|
offline: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
@ -540,9 +88,20 @@ impl Args {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let start = std::time::Instant::now();
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
println!("tracing...");
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||||
let device = &model.device;
|
let device = &model.device;
|
||||||
|
|
||||||
|
525
candle-examples/examples/bert/model.rs
Normal file
525
candle-examples/examples/bert/model.rs
Normal file
@ -0,0 +1,525 @@
|
|||||||
|
use candle::{DType, Device, Result, Tensor};
|
||||||
|
use candle_nn::{Embedding, LayerNorm, VarBuilder};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
pub const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
enum HiddenAct {
|
||||||
|
Gelu,
|
||||||
|
Relu,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HiddenActLayer {
|
||||||
|
act: HiddenAct,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HiddenActLayer {
|
||||||
|
fn new(act: HiddenAct) -> Self {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
|
||||||
|
Self { act, span }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
match self.act {
|
||||||
|
// TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
|
||||||
|
// small numerical difference.
|
||||||
|
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
||||||
|
HiddenAct::Gelu => xs.gelu(),
|
||||||
|
HiddenAct::Relu => xs.relu(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Linear {
|
||||||
|
weight: Tensor,
|
||||||
|
bias: Option<Tensor>,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Linear {
|
||||||
|
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||||
|
Self { weight, bias, span }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let w = match x.dims() {
|
||||||
|
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||||
|
_ => self.weight.t()?,
|
||||||
|
};
|
||||||
|
let x = x.matmul(&w)?;
|
||||||
|
match &self.bias {
|
||||||
|
None => Ok(x),
|
||||||
|
Some(bias) => x.broadcast_add(bias),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
enum PositionEmbeddingType {
|
||||||
|
#[default]
|
||||||
|
Absolute,
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
vocab_size: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
num_hidden_layers: usize,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
intermediate_size: usize,
|
||||||
|
hidden_act: HiddenAct,
|
||||||
|
hidden_dropout_prob: f64,
|
||||||
|
max_position_embeddings: usize,
|
||||||
|
type_vocab_size: usize,
|
||||||
|
initializer_range: f64,
|
||||||
|
layer_norm_eps: f64,
|
||||||
|
pad_token_id: usize,
|
||||||
|
#[serde(default)]
|
||||||
|
position_embedding_type: PositionEmbeddingType,
|
||||||
|
#[serde(default)]
|
||||||
|
use_cache: bool,
|
||||||
|
classifier_dropout: Option<f64>,
|
||||||
|
model_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 30522,
|
||||||
|
hidden_size: 768,
|
||||||
|
num_hidden_layers: 12,
|
||||||
|
num_attention_heads: 12,
|
||||||
|
intermediate_size: 3072,
|
||||||
|
hidden_act: HiddenAct::Gelu,
|
||||||
|
hidden_dropout_prob: 0.1,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
type_vocab_size: 2,
|
||||||
|
initializer_range: 0.02,
|
||||||
|
layer_norm_eps: 1e-12,
|
||||||
|
pad_token_id: 0,
|
||||||
|
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||||
|
use_cache: true,
|
||||||
|
classifier_dropout: None,
|
||||||
|
model_type: Some("bert".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
fn _all_mini_lm_l6_v2() -> Self {
|
||||||
|
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
|
||||||
|
Self {
|
||||||
|
vocab_size: 30522,
|
||||||
|
hidden_size: 384,
|
||||||
|
num_hidden_layers: 6,
|
||||||
|
num_attention_heads: 12,
|
||||||
|
intermediate_size: 1536,
|
||||||
|
hidden_act: HiddenAct::Gelu,
|
||||||
|
hidden_dropout_prob: 0.1,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
type_vocab_size: 2,
|
||||||
|
initializer_range: 0.02,
|
||||||
|
layer_norm_eps: 1e-12,
|
||||||
|
pad_token_id: 0,
|
||||||
|
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||||
|
use_cache: true,
|
||||||
|
classifier_dropout: None,
|
||||||
|
model_type: Some("bert".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||||
|
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||||
|
Ok(Embedding::new(embeddings, hidden_size))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
|
let weight = vb.get((size2, size1), "weight")?;
|
||||||
|
let bias = vb.get(size2, "bias")?;
|
||||||
|
Ok(Linear::new(weight, Some(bias)))
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Dropout {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pr: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Dropout {
|
||||||
|
fn new(pr: f64) -> Self {
|
||||||
|
Self { pr }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
// TODO
|
||||||
|
Ok(x.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||||
|
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||||
|
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||||
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
|
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||||
|
(weight, bias)
|
||||||
|
} else {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(LayerNorm::new(weight, bias, eps))
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
|
||||||
|
struct BertEmbeddings {
|
||||||
|
word_embeddings: Embedding,
|
||||||
|
position_embeddings: Option<Embedding>,
|
||||||
|
token_type_embeddings: Embedding,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
dropout: Dropout,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertEmbeddings {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let word_embeddings = embedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
vb.pp("word_embeddings"),
|
||||||
|
)?;
|
||||||
|
let position_embeddings = embedding(
|
||||||
|
config.max_position_embeddings,
|
||||||
|
config.hidden_size,
|
||||||
|
vb.pp("position_embeddings"),
|
||||||
|
)?;
|
||||||
|
let token_type_embeddings = embedding(
|
||||||
|
config.type_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
vb.pp("token_type_embeddings"),
|
||||||
|
)?;
|
||||||
|
let layer_norm = layer_norm(
|
||||||
|
config.hidden_size,
|
||||||
|
config.layer_norm_eps,
|
||||||
|
vb.pp("LayerNorm"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
word_embeddings,
|
||||||
|
position_embeddings: Some(position_embeddings),
|
||||||
|
token_type_embeddings,
|
||||||
|
layer_norm,
|
||||||
|
dropout: Dropout::new(config.hidden_dropout_prob),
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let (_bsize, seq_len) = input_ids.shape().r2()?;
|
||||||
|
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||||
|
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
||||||
|
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
|
||||||
|
if let Some(position_embeddings) = &self.position_embeddings {
|
||||||
|
// TODO: Proper absolute positions?
|
||||||
|
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
|
||||||
|
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
|
||||||
|
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
|
||||||
|
}
|
||||||
|
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||||
|
let embeddings = self.dropout.forward(&embeddings)?;
|
||||||
|
Ok(embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct BertSelfAttention {
|
||||||
|
query: Linear,
|
||||||
|
key: Linear,
|
||||||
|
value: Linear,
|
||||||
|
dropout: Dropout,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
attention_head_size: usize,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertSelfAttention {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||||
|
let all_head_size = config.num_attention_heads * attention_head_size;
|
||||||
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
|
let hidden_size = config.hidden_size;
|
||||||
|
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
|
||||||
|
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
|
||||||
|
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
|
||||||
|
Ok(Self {
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
dropout,
|
||||||
|
num_attention_heads: config.num_attention_heads,
|
||||||
|
attention_head_size,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut new_x_shape = xs.dims().to_vec();
|
||||||
|
new_x_shape.pop();
|
||||||
|
new_x_shape.push(self.num_attention_heads);
|
||||||
|
new_x_shape.push(self.attention_head_size);
|
||||||
|
// Be cautious about the transposition if adding a batch dim!
|
||||||
|
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
|
||||||
|
xs.contiguous()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let query_layer = self.query.forward(hidden_states)?;
|
||||||
|
let key_layer = self.key.forward(hidden_states)?;
|
||||||
|
let value_layer = self.value.forward(hidden_states)?;
|
||||||
|
|
||||||
|
let query_layer = self.transpose_for_scores(&query_layer)?;
|
||||||
|
let key_layer = self.transpose_for_scores(&key_layer)?;
|
||||||
|
let value_layer = self.transpose_for_scores(&value_layer)?;
|
||||||
|
|
||||||
|
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
|
||||||
|
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
|
||||||
|
let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
|
||||||
|
let attention_probs = self.dropout.forward(&attention_probs)?;
|
||||||
|
|
||||||
|
let context_layer = attention_probs.matmul(&value_layer)?;
|
||||||
|
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
|
||||||
|
let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
|
||||||
|
Ok(context_layer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct BertSelfOutput {
|
||||||
|
dense: Linear,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
dropout: Dropout,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertSelfOutput {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
|
||||||
|
let layer_norm = layer_norm(
|
||||||
|
config.hidden_size,
|
||||||
|
config.layer_norm_eps,
|
||||||
|
vb.pp("LayerNorm"),
|
||||||
|
)?;
|
||||||
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
|
Ok(Self {
|
||||||
|
dense,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "self-out"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let hidden_states = self.dense.forward(hidden_states)?;
|
||||||
|
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||||
|
self.layer_norm.forward(&(hidden_states + input_tensor)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
|
||||||
|
struct BertAttention {
|
||||||
|
self_attention: BertSelfAttention,
|
||||||
|
self_output: BertSelfOutput,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertAttention {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
|
||||||
|
let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attention,
|
||||||
|
self_output,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "attn"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let self_outputs = self.self_attention.forward(hidden_states)?;
|
||||||
|
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
|
||||||
|
Ok(attention_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
|
||||||
|
struct BertIntermediate {
|
||||||
|
dense: Linear,
|
||||||
|
intermediate_act: HiddenActLayer,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertIntermediate {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
|
||||||
|
Ok(Self {
|
||||||
|
dense,
|
||||||
|
intermediate_act: HiddenActLayer::new(config.hidden_act),
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "inter"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let hidden_states = self.dense.forward(hidden_states)?;
|
||||||
|
let ys = self.intermediate_act.forward(&hidden_states)?;
|
||||||
|
Ok(ys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
|
||||||
|
struct BertOutput {
|
||||||
|
dense: Linear,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
dropout: Dropout,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertOutput {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
|
||||||
|
let layer_norm = layer_norm(
|
||||||
|
config.hidden_size,
|
||||||
|
config.layer_norm_eps,
|
||||||
|
vb.pp("LayerNorm"),
|
||||||
|
)?;
|
||||||
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
|
Ok(Self {
|
||||||
|
dense,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "out"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let hidden_states = self.dense.forward(hidden_states)?;
|
||||||
|
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||||
|
self.layer_norm.forward(&(hidden_states + input_tensor)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
|
||||||
|
struct BertLayer {
|
||||||
|
attention: BertAttention,
|
||||||
|
intermediate: BertIntermediate,
|
||||||
|
output: BertOutput,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertLayer {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let attention = BertAttention::load(vb.pp("attention"), config)?;
|
||||||
|
let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
|
||||||
|
let output = BertOutput::load(vb.pp("output"), config)?;
|
||||||
|
Ok(Self {
|
||||||
|
attention,
|
||||||
|
intermediate,
|
||||||
|
output,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "layer"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let attention_output = self.attention.forward(hidden_states)?;
|
||||||
|
// TODO: Support cross-attention?
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
|
||||||
|
// TODO: Support something similar to `apply_chunking_to_forward`?
|
||||||
|
let intermediate_output = self.intermediate.forward(&attention_output)?;
|
||||||
|
let layer_output = self
|
||||||
|
.output
|
||||||
|
.forward(&intermediate_output, &attention_output)?;
|
||||||
|
Ok(layer_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
|
||||||
|
struct BertEncoder {
|
||||||
|
layers: Vec<BertLayer>,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertEncoder {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let layers = (0..config.num_hidden_layers)
|
||||||
|
.map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
||||||
|
Ok(BertEncoder { layers, span })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let mut hidden_states = hidden_states.clone();
|
||||||
|
// Use a loop rather than a fold as it's easier to modify when adding debug/...
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
hidden_states = layer.forward(&hidden_states)?
|
||||||
|
}
|
||||||
|
Ok(hidden_states)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
|
||||||
|
pub struct BertModel {
|
||||||
|
embeddings: BertEmbeddings,
|
||||||
|
encoder: BertEncoder,
|
||||||
|
pub device: Device,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertModel {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let (embeddings, encoder) = match (
|
||||||
|
BertEmbeddings::load(vb.pp("embeddings"), config),
|
||||||
|
BertEncoder::load(vb.pp("encoder"), config),
|
||||||
|
) {
|
||||||
|
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||||
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
|
if let Some(model_type) = &config.model_type {
|
||||||
|
if let (Ok(embeddings), Ok(encoder)) = (
|
||||||
|
BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
||||||
|
BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
|
||||||
|
) {
|
||||||
|
(embeddings, encoder)
|
||||||
|
} else {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
encoder,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "model"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
|
||||||
|
let sequence_output = self.encoder.forward(&embedding_output)?;
|
||||||
|
Ok(sequence_output)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user