From f0cccd08f0cf66ac6a93049785249cc113514c8a Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 17 Jul 2023 19:40:42 +0100 Subject: [PATCH] Bert tracing (#184) * Add some tracing to bert. * More tracing. * Add a flag for tracing. --- .gitignore | 1 + Cargo.toml | 3 + candle-examples/Cargo.toml | 3 + candle-examples/examples/bert/main.rs | 481 +--------------------- candle-examples/examples/bert/model.rs | 525 +++++++++++++++++++++++++ 5 files changed, 552 insertions(+), 461 deletions(-) create mode 100644 candle-examples/examples/bert/model.rs diff --git a/.gitignore b/.gitignore index e8d63fbb..df9a6132 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ perf.data flamegraph.svg *.so *.swp +trace-*.json candle-wasm-example/*.wav candle-wasm-example/*.safetensors diff --git a/Cargo.toml b/Cargo.toml index ebb4ba5d..d8efc38b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,9 @@ thiserror = "1" tokenizers = { version = "0.13.3", default-features = false, features = ["onig"] } tokio = "1.28.2" tokio-test = "0.4.2" +tracing = "0.1.37" +tracing-chrome = "0.7.1" +tracing-subscriber = "0.3.7" wav = "1.0.0" zip = { version = "0.6.6", default-features = false } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index d2618864..29451030 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -25,6 +25,9 @@ candle-hub = { path = "../candle-hub" } clap = { workspace = true } rand = { workspace = true } tokenizers = { workspace = true } +tracing = { workspace = true } +tracing-chrome = { workspace = true } +tracing-subscriber = { workspace = true } wav = { workspace = true } [features] diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index dca6721b..958b70b1 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -1,471 +1,15 @@ #[cfg(feature = "mkl")] extern crate intel_mkl_src; +mod model; use anyhow::{anyhow, Error as E, Result}; -use candle::{DType, Device, Tensor}; +use candle::Tensor; use candle_hub::{api::sync::Api, Cache, Repo, RepoType}; -use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder}; +use candle_nn::VarBuilder; use clap::Parser; -use serde::Deserialize; +use model::{BertModel, Config, DTYPE}; use tokenizers::{PaddingParams, Tokenizer}; -const DTYPE: DType = DType::F32; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] -#[serde(rename_all = "lowercase")] -enum HiddenAct { - Gelu, - Relu, -} - -impl HiddenAct { - fn forward(&self, xs: &Tensor) -> candle::Result { - match self { - // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some - // small numerical difference. - // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 - Self::Gelu => xs.gelu(), - Self::Relu => xs.relu(), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] -#[serde(rename_all = "lowercase")] -enum PositionEmbeddingType { - #[default] - Absolute, -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1 -#[derive(Debug, Clone, PartialEq, Deserialize)] -struct Config { - vocab_size: usize, - hidden_size: usize, - num_hidden_layers: usize, - num_attention_heads: usize, - intermediate_size: usize, - hidden_act: HiddenAct, - hidden_dropout_prob: f64, - max_position_embeddings: usize, - type_vocab_size: usize, - initializer_range: f64, - layer_norm_eps: f64, - pad_token_id: usize, - #[serde(default)] - position_embedding_type: PositionEmbeddingType, - #[serde(default)] - use_cache: bool, - classifier_dropout: Option, - model_type: Option, -} - -impl Default for Config { - fn default() -> Self { - Self { - vocab_size: 30522, - hidden_size: 768, - num_hidden_layers: 12, - num_attention_heads: 12, - intermediate_size: 3072, - hidden_act: HiddenAct::Gelu, - hidden_dropout_prob: 0.1, - max_position_embeddings: 512, - type_vocab_size: 2, - initializer_range: 0.02, - layer_norm_eps: 1e-12, - pad_token_id: 0, - position_embedding_type: PositionEmbeddingType::Absolute, - use_cache: true, - classifier_dropout: None, - model_type: Some("bert".to_string()), - } - } -} - -impl Config { - fn _all_mini_lm_l6_v2() -> Self { - // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json - Self { - vocab_size: 30522, - hidden_size: 384, - num_hidden_layers: 6, - num_attention_heads: 12, - intermediate_size: 1536, - hidden_act: HiddenAct::Gelu, - hidden_dropout_prob: 0.1, - max_position_embeddings: 512, - type_vocab_size: 2, - initializer_range: 0.02, - layer_norm_eps: 1e-12, - pad_token_id: 0, - position_embedding_type: PositionEmbeddingType::Absolute, - use_cache: true, - classifier_dropout: None, - model_type: Some("bert".to_string()), - } - } -} - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let weight = vb.get((size2, size1), "weight")?; - let bias = vb.get(size2, "bias")?; - Ok(Linear::new(weight, Some(bias))) -} - -struct Dropout { - #[allow(dead_code)] - pr: f64, -} - -impl Dropout { - fn new(pr: f64) -> Self { - Self { pr } - } - - fn forward(&self, x: &Tensor) -> Result { - // TODO - Ok(x.clone()) - } -} - -fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { - let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { - (Ok(weight), Ok(bias)) => (weight, bias), - (Err(err), _) | (_, Err(err)) => { - if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { - (weight, bias) - } else { - return Err(err.into()); - } - } - }; - Ok(LayerNorm::new(weight, bias, eps)) -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180 -struct BertEmbeddings { - word_embeddings: Embedding, - position_embeddings: Option, - token_type_embeddings: Embedding, - layer_norm: LayerNorm, - dropout: Dropout, -} - -impl BertEmbeddings { - fn load(vb: VarBuilder, config: &Config) -> Result { - let word_embeddings = embedding( - config.vocab_size, - config.hidden_size, - vb.pp("word_embeddings"), - )?; - let position_embeddings = embedding( - config.max_position_embeddings, - config.hidden_size, - vb.pp("position_embeddings"), - )?; - let token_type_embeddings = embedding( - config.type_vocab_size, - config.hidden_size, - vb.pp("token_type_embeddings"), - )?; - let layer_norm = layer_norm( - config.hidden_size, - config.layer_norm_eps, - vb.pp("LayerNorm"), - )?; - Ok(Self { - word_embeddings, - position_embeddings: Some(position_embeddings), - token_type_embeddings, - layer_norm, - dropout: Dropout::new(config.hidden_dropout_prob), - }) - } - - fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { - let (_bsize, seq_len) = input_ids.shape().r2()?; - let input_embeddings = self.word_embeddings.forward(input_ids)?; - let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; - let mut embeddings = (&input_embeddings + token_type_embeddings)?; - if let Some(position_embeddings) = &self.position_embeddings { - // TODO: Proper absolute positions? - let position_ids = (0..seq_len as u32).collect::>(); - let position_ids = Tensor::new(&position_ids[..], input_ids.device())?; - embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? - } - let embeddings = self.layer_norm.forward(&embeddings)?; - let embeddings = self.dropout.forward(&embeddings)?; - Ok(embeddings) - } -} - -struct BertSelfAttention { - query: Linear, - key: Linear, - value: Linear, - dropout: Dropout, - num_attention_heads: usize, - attention_head_size: usize, -} - -impl BertSelfAttention { - fn load(vb: VarBuilder, config: &Config) -> Result { - let attention_head_size = config.hidden_size / config.num_attention_heads; - let all_head_size = config.num_attention_heads * attention_head_size; - let dropout = Dropout::new(config.hidden_dropout_prob); - let hidden_size = config.hidden_size; - let query = linear(hidden_size, all_head_size, vb.pp("query"))?; - let value = linear(hidden_size, all_head_size, vb.pp("value"))?; - let key = linear(hidden_size, all_head_size, vb.pp("key"))?; - Ok(Self { - query, - key, - value, - dropout, - num_attention_heads: config.num_attention_heads, - attention_head_size, - }) - } - - fn transpose_for_scores(&self, xs: &Tensor) -> Result { - let mut new_x_shape = xs.dims().to_vec(); - new_x_shape.pop(); - new_x_shape.push(self.num_attention_heads); - new_x_shape.push(self.attention_head_size); - // Be cautious about the transposition if adding a batch dim! - let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; - Ok(xs.contiguous()?) - } - - fn forward(&self, hidden_states: &Tensor) -> Result { - let query_layer = self.query.forward(hidden_states)?; - let key_layer = self.key.forward(hidden_states)?; - let value_layer = self.value.forward(hidden_states)?; - - let query_layer = self.transpose_for_scores(&query_layer)?; - let key_layer = self.transpose_for_scores(&key_layer)?; - let value_layer = self.transpose_for_scores(&value_layer)?; - - let attention_scores = query_layer.matmul(&key_layer.t()?)?; - let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; - let attention_probs = attention_scores.softmax(candle::D::Minus1)?; - let attention_probs = self.dropout.forward(&attention_probs)?; - - let context_layer = attention_probs.matmul(&value_layer)?; - let context_layer = context_layer.transpose(1, 2)?.contiguous()?; - let context_layer = context_layer.flatten_from(candle::D::Minus2)?; - Ok(context_layer) - } -} - -struct BertSelfOutput { - dense: Linear, - layer_norm: LayerNorm, - dropout: Dropout, -} - -impl BertSelfOutput { - fn load(vb: VarBuilder, config: &Config) -> Result { - let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; - let layer_norm = layer_norm( - config.hidden_size, - config.layer_norm_eps, - vb.pp("LayerNorm"), - )?; - let dropout = Dropout::new(config.hidden_dropout_prob); - Ok(Self { - dense, - layer_norm, - dropout, - }) - } - - fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { - let hidden_states = self.dense.forward(hidden_states)?; - let hidden_states = self.dropout.forward(&hidden_states)?; - Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392 -struct BertAttention { - self_attention: BertSelfAttention, - self_output: BertSelfOutput, -} - -impl BertAttention { - fn load(vb: VarBuilder, config: &Config) -> Result { - let self_attention = BertSelfAttention::load(vb.pp("self"), config)?; - let self_output = BertSelfOutput::load(vb.pp("output"), config)?; - Ok(Self { - self_attention, - self_output, - }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result { - let self_outputs = self.self_attention.forward(hidden_states)?; - let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; - Ok(attention_output) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441 -struct BertIntermediate { - dense: Linear, - intermediate_act: HiddenAct, -} - -impl BertIntermediate { - fn load(vb: VarBuilder, config: &Config) -> Result { - let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; - Ok(Self { - dense, - intermediate_act: config.hidden_act, - }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result { - let hidden_states = self.dense.forward(hidden_states)?; - let ys = self.intermediate_act.forward(&hidden_states)?; - Ok(ys) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456 -struct BertOutput { - dense: Linear, - layer_norm: LayerNorm, - dropout: Dropout, -} - -impl BertOutput { - fn load(vb: VarBuilder, config: &Config) -> Result { - let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; - let layer_norm = layer_norm( - config.hidden_size, - config.layer_norm_eps, - vb.pp("LayerNorm"), - )?; - let dropout = Dropout::new(config.hidden_dropout_prob); - Ok(Self { - dense, - layer_norm, - dropout, - }) - } - - fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { - let hidden_states = self.dense.forward(hidden_states)?; - let hidden_states = self.dropout.forward(&hidden_states)?; - Ok(self.layer_norm.forward(&(hidden_states + input_tensor)?)?) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470 -struct BertLayer { - attention: BertAttention, - intermediate: BertIntermediate, - output: BertOutput, -} - -impl BertLayer { - fn load(vb: VarBuilder, config: &Config) -> Result { - let attention = BertAttention::load(vb.pp("attention"), config)?; - let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?; - let output = BertOutput::load(vb.pp("output"), config)?; - Ok(Self { - attention, - intermediate, - output, - }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result { - let attention_output = self.attention.forward(hidden_states)?; - // TODO: Support cross-attention? - // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 - // TODO: Support something similar to `apply_chunking_to_forward`? - let intermediate_output = self.intermediate.forward(&attention_output)?; - let layer_output = self - .output - .forward(&intermediate_output, &attention_output)?; - Ok(layer_output) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556 -struct BertEncoder { - layers: Vec, -} - -impl BertEncoder { - fn load(vb: VarBuilder, config: &Config) -> Result { - let layers = (0..config.num_hidden_layers) - .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config)) - .collect::>>()?; - Ok(BertEncoder { layers }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result { - let mut hidden_states = hidden_states.clone(); - // Use a loop rather than a fold as it's easier to modify when adding debug/... - for layer in self.layers.iter() { - hidden_states = layer.forward(&hidden_states)? - } - Ok(hidden_states) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874 -struct BertModel { - embeddings: BertEmbeddings, - encoder: BertEncoder, - device: Device, -} - -impl BertModel { - fn load(vb: VarBuilder, config: &Config) -> Result { - let (embeddings, encoder) = match ( - BertEmbeddings::load(vb.pp("embeddings"), config), - BertEncoder::load(vb.pp("encoder"), config), - ) { - (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), - (Err(err), _) | (_, Err(err)) => { - if let Some(model_type) = &config.model_type { - if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), - BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), - ) { - (embeddings, encoder) - } else { - return Err(err); - } - } else { - return Err(err); - } - } - }; - Ok(Self { - embeddings, - encoder, - device: vb.device().clone(), - }) - } - - fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { - let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?; - let sequence_output = self.encoder.forward(&embedding_output)?; - Ok(sequence_output) - } -} - #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -477,6 +21,10 @@ struct Args { #[arg(long)] offline: bool, + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending #[arg(long)] model_id: Option, @@ -540,9 +88,20 @@ impl Args { } fn main() -> Result<()> { - let start = std::time::Instant::now(); + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; let args = Args::parse(); + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + let start = std::time::Instant::now(); + let (model, mut tokenizer) = args.build_model_and_tokenizer()?; let device = &model.device; diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs new file mode 100644 index 00000000..059f4280 --- /dev/null +++ b/candle-examples/examples/bert/model.rs @@ -0,0 +1,525 @@ +use candle::{DType, Device, Result, Tensor}; +use candle_nn::{Embedding, LayerNorm, VarBuilder}; +use serde::Deserialize; + +pub const DTYPE: DType = DType::F32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] +enum HiddenAct { + Gelu, + Relu, +} + +struct HiddenActLayer { + act: HiddenAct, + span: tracing::Span, +} + +impl HiddenActLayer { + fn new(act: HiddenAct) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "hidden-act"); + Self { act, span } + } + + fn forward(&self, xs: &Tensor) -> candle::Result { + let _enter = self.span.enter(); + match self.act { + // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some + // small numerical difference. + // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 + HiddenAct::Gelu => xs.gelu(), + HiddenAct::Relu => xs.relu(), + } + } +} + +#[derive(Debug)] +pub struct Linear { + weight: Tensor, + bias: Option, + span: tracing::Span, +} + +impl Linear { + pub fn new(weight: Tensor, bias: Option) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Self { weight, bias, span } + } + + pub fn forward(&self, x: &Tensor) -> candle::Result { + let _enter = self.span.enter(); + let w = match x.dims() { + &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, + _ => self.weight.t()?, + }; + let x = x.matmul(&w)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +enum PositionEmbeddingType { + #[default] + Absolute, +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + vocab_size: usize, + hidden_size: usize, + num_hidden_layers: usize, + num_attention_heads: usize, + intermediate_size: usize, + hidden_act: HiddenAct, + hidden_dropout_prob: f64, + max_position_embeddings: usize, + type_vocab_size: usize, + initializer_range: f64, + layer_norm_eps: f64, + pad_token_id: usize, + #[serde(default)] + position_embedding_type: PositionEmbeddingType, + #[serde(default)] + use_cache: bool, + classifier_dropout: Option, + model_type: Option, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 30522, + hidden_size: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + intermediate_size: 3072, + hidden_act: HiddenAct::Gelu, + hidden_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + classifier_dropout: None, + model_type: Some("bert".to_string()), + } + } +} + +impl Config { + fn _all_mini_lm_l6_v2() -> Self { + // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json + Self { + vocab_size: 30522, + hidden_size: 384, + num_hidden_layers: 6, + num_attention_heads: 12, + intermediate_size: 1536, + hidden_act: HiddenAct::Gelu, + hidden_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + classifier_dropout: None, + model_type: Some("bert".to_string()), + } + } +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + let bias = vb.get(size2, "bias")?; + Ok(Linear::new(weight, Some(bias))) +} + +struct Dropout { + #[allow(dead_code)] + pr: f64, +} + +impl Dropout { + fn new(pr: f64) -> Self { + Self { pr } + } + + fn forward(&self, x: &Tensor) -> Result { + // TODO + Ok(x.clone()) + } +} + +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { + let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { + (Ok(weight), Ok(bias)) => (weight, bias), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { + (weight, bias) + } else { + return Err(err); + } + } + }; + Ok(LayerNorm::new(weight, bias, eps)) +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180 +struct BertEmbeddings { + word_embeddings: Embedding, + position_embeddings: Option, + token_type_embeddings: Embedding, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertEmbeddings { + fn load(vb: VarBuilder, config: &Config) -> Result { + let word_embeddings = embedding( + config.vocab_size, + config.hidden_size, + vb.pp("word_embeddings"), + )?; + let position_embeddings = embedding( + config.max_position_embeddings, + config.hidden_size, + vb.pp("position_embeddings"), + )?; + let token_type_embeddings = embedding( + config.type_vocab_size, + config.hidden_size, + vb.pp("token_type_embeddings"), + )?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + Ok(Self { + word_embeddings, + position_embeddings: Some(position_embeddings), + token_type_embeddings, + layer_norm, + dropout: Dropout::new(config.hidden_dropout_prob), + span: tracing::span!(tracing::Level::TRACE, "embeddings"), + }) + } + + fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { + let _enter = self.span.enter(); + let (_bsize, seq_len) = input_ids.shape().r2()?; + let input_embeddings = self.word_embeddings.forward(input_ids)?; + let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; + let mut embeddings = (&input_embeddings + token_type_embeddings)?; + if let Some(position_embeddings) = &self.position_embeddings { + // TODO: Proper absolute positions? + let position_ids = (0..seq_len as u32).collect::>(); + let position_ids = Tensor::new(&position_ids[..], input_ids.device())?; + embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? + } + let embeddings = self.layer_norm.forward(&embeddings)?; + let embeddings = self.dropout.forward(&embeddings)?; + Ok(embeddings) + } +} + +struct BertSelfAttention { + query: Linear, + key: Linear, + value: Linear, + dropout: Dropout, + num_attention_heads: usize, + attention_head_size: usize, + span: tracing::Span, +} + +impl BertSelfAttention { + fn load(vb: VarBuilder, config: &Config) -> Result { + let attention_head_size = config.hidden_size / config.num_attention_heads; + let all_head_size = config.num_attention_heads * attention_head_size; + let dropout = Dropout::new(config.hidden_dropout_prob); + let hidden_size = config.hidden_size; + let query = linear(hidden_size, all_head_size, vb.pp("query"))?; + let value = linear(hidden_size, all_head_size, vb.pp("value"))?; + let key = linear(hidden_size, all_head_size, vb.pp("key"))?; + Ok(Self { + query, + key, + value, + dropout, + num_attention_heads: config.num_attention_heads, + attention_head_size, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + }) + } + + fn transpose_for_scores(&self, xs: &Tensor) -> Result { + let mut new_x_shape = xs.dims().to_vec(); + new_x_shape.pop(); + new_x_shape.push(self.num_attention_heads); + new_x_shape.push(self.attention_head_size); + // Be cautious about the transposition if adding a batch dim! + let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; + xs.contiguous() + } + + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + let query_layer = self.query.forward(hidden_states)?; + let key_layer = self.key.forward(hidden_states)?; + let value_layer = self.value.forward(hidden_states)?; + + let query_layer = self.transpose_for_scores(&query_layer)?; + let key_layer = self.transpose_for_scores(&key_layer)?; + let value_layer = self.transpose_for_scores(&value_layer)?; + + let attention_scores = query_layer.matmul(&key_layer.t()?)?; + let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; + let attention_probs = attention_scores.softmax(candle::D::Minus1)?; + let attention_probs = self.dropout.forward(&attention_probs)?; + + let context_layer = attention_probs.matmul(&value_layer)?; + let context_layer = context_layer.transpose(1, 2)?.contiguous()?; + let context_layer = context_layer.flatten_from(candle::D::Minus2)?; + Ok(context_layer) + } +} + +struct BertSelfOutput { + dense: Linear, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertSelfOutput { + fn load(vb: VarBuilder, config: &Config) -> Result { + let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + let dropout = Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "self-out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392 +struct BertAttention { + self_attention: BertSelfAttention, + self_output: BertSelfOutput, + span: tracing::Span, +} + +impl BertAttention { + fn load(vb: VarBuilder, config: &Config) -> Result { + let self_attention = BertSelfAttention::load(vb.pp("self"), config)?; + let self_output = BertSelfOutput::load(vb.pp("output"), config)?; + Ok(Self { + self_attention, + self_output, + span: tracing::span!(tracing::Level::TRACE, "attn"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + let self_outputs = self.self_attention.forward(hidden_states)?; + let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; + Ok(attention_output) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441 +struct BertIntermediate { + dense: Linear, + intermediate_act: HiddenActLayer, + span: tracing::Span, +} + +impl BertIntermediate { + fn load(vb: VarBuilder, config: &Config) -> Result { + let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; + Ok(Self { + dense, + intermediate_act: HiddenActLayer::new(config.hidden_act), + span: tracing::span!(tracing::Level::TRACE, "inter"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let ys = self.intermediate_act.forward(&hidden_states)?; + Ok(ys) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456 +struct BertOutput { + dense: Linear, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertOutput { + fn load(vb: VarBuilder, config: &Config) -> Result { + let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + let dropout = Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470 +struct BertLayer { + attention: BertAttention, + intermediate: BertIntermediate, + output: BertOutput, + span: tracing::Span, +} + +impl BertLayer { + fn load(vb: VarBuilder, config: &Config) -> Result { + let attention = BertAttention::load(vb.pp("attention"), config)?; + let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?; + let output = BertOutput::load(vb.pp("output"), config)?; + Ok(Self { + attention, + intermediate, + output, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + let attention_output = self.attention.forward(hidden_states)?; + // TODO: Support cross-attention? + // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 + // TODO: Support something similar to `apply_chunking_to_forward`? + let intermediate_output = self.intermediate.forward(&attention_output)?; + let layer_output = self + .output + .forward(&intermediate_output, &attention_output)?; + Ok(layer_output) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556 +struct BertEncoder { + layers: Vec, + span: tracing::Span, +} + +impl BertEncoder { + fn load(vb: VarBuilder, config: &Config) -> Result { + let layers = (0..config.num_hidden_layers) + .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config)) + .collect::>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + Ok(BertEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + let mut hidden_states = hidden_states.clone(); + // Use a loop rather than a fold as it's easier to modify when adding debug/... + for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states)? + } + Ok(hidden_states) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874 +pub struct BertModel { + embeddings: BertEmbeddings, + encoder: BertEncoder, + pub device: Device, + span: tracing::Span, +} + +impl BertModel { + pub fn load(vb: VarBuilder, config: &Config) -> Result { + let (embeddings, encoder) = match ( + BertEmbeddings::load(vb.pp("embeddings"), config), + BertEncoder::load(vb.pp("encoder"), config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + if let Some(model_type) = &config.model_type { + if let (Ok(embeddings), Ok(encoder)) = ( + BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), + BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), + ) { + (embeddings, encoder) + } else { + return Err(err); + } + } else { + return Err(err); + } + } + }; + Ok(Self { + embeddings, + encoder, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { + let _enter = self.span.enter(); + let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?; + let sequence_output = self.encoder.forward(&embedding_output)?; + Ok(sequence_output) + } +}