diff --git a/candle-examples/examples/distilbert/README.md b/candle-examples/examples/distilbert/README.md
new file mode 100644
index 00000000..88f97f2b
--- /dev/null
+++ b/candle-examples/examples/distilbert/README.md
@@ -0,0 +1,22 @@
+# candle-distilbert
+
+DistilBert is a distilled version of the Bert model.
+
+## Sentence embeddings
+
+DistilBert is used to compute the sentence embeddings for a prompt. The model weights
+are downloaded from the hub on the first run.
+
+```bash
+cargo run --example distilbert --release -- --prompt "Here is a test sentence"
+
+> [[[ 0.5109,  0.1280, -0.2635, ...,  0.3462, -1.0434,  0.1441],
+>  [ 0.1735,  0.0818, -0.5549, ...,  0.3472, -0.8264, -0.0244],
+>  [ 0.0702, -0.1311, -0.4914, ...,  0.3483, -0.6194,  0.1829],
+>  ...
+>  [ 0.2993, -0.0106, -0.4640, ...,  0.2844, -0.6732,  0.0042],
+>  [ 0.1066, -0.0081, -0.4299, ...,  0.3435, -0.7729,  0.0190],
+>  [ 0.8903,  0.2055, -0.2541, ...,  0.3208, -0.6585,  0.0586]]]
+> Tensor[[1, 7, 768], f32]
+
+```
diff --git a/candle-examples/examples/distilbert/main.rs b/candle-examples/examples/distilbert/main.rs
new file mode 100644
index 00000000..1d42011c
--- /dev/null
+++ b/candle-examples/examples/distilbert/main.rs
@@ -0,0 +1,135 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: String,
+
+    /// Use the pytorch weights rather than the safetensors ones
+    #[arg(long)]
+    use_pth: bool,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+
+    /// L2 normalization for embeddings.
+    #[arg(long, default_value = "true")]
+    normalize_embeddings: bool,
+}
+
+impl Args {
+    fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let default_model = "distilbert-base-uncased".to_string();
+        let default_revision = "main".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = if self.use_pth {
+                api.get("pytorch_model.bin")?
+            } else {
+                api.get("model.safetensors")?
+            };
+            (config, tokenizer, weights)
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let vb = if self.use_pth {
+            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+        } else {
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+        };
+        let model = DistilBertModel::load(vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
+fn get_mask(size: usize, device: &Device) -> Tensor {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device).unwrap()
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let device = &model.device;
+
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
+    let tokens = tokenizer
+        .encode(args.prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let mask = get_mask(tokens.len(), device);
+
+    println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
+    println!("mask: {:?}", mask.to_vec2::<u8>());
+
+    let ys = model.forward(&token_ids, &mask)?;
+    println!("{ys}");
+
+    Ok(())
+}
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs
new file mode 100644
index 00000000..ea074c97
--- /dev/null
+++ b/candle-transformers/src/models/distilbert.rs
@@ -0,0 +1,342 @@
+use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, Module, VarBuilder};
+use serde::Deserialize;
+
+pub const DTYPE: DType = DType::F32;
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
+enum HiddenAct {
+    Gelu,
+    Relu,
+}
+
+struct HiddenActLayer {
+    act: HiddenAct,
+    span: tracing::Span,
+}
+
+impl HiddenActLayer {
+    fn new(act: HiddenAct) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
+        Self { act, span }
+    }
+}
+
+impl Module for HiddenActLayer {
+    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+        let _enter = self.span.enter();
+        match self.act {
+            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
+            HiddenAct::Gelu => xs.gelu(),
+            HiddenAct::Relu => xs.relu(),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
+enum PositionEmbeddingType {
+    #[default]
+    Absolute,
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    vocab_size: usize,
+    dim: usize,
+    n_layers: usize,
+    n_heads: usize,
+    hidden_dim: usize,
+    activation: HiddenAct,
+    max_position_embeddings: usize,
+    initializer_range: f64,
+    pad_token_id: usize,
+    #[serde(default)]
+    position_embedding_type: PositionEmbeddingType,
+    #[serde(default)]
+    use_cache: bool,
+    model_type: Option<String>,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            vocab_size: 30522,
+            dim: 768,
+            n_layers: 12,
+            n_heads: 12,
+            hidden_dim: 3072,
+            activation: HiddenAct::Gelu,
+            max_position_embeddings: 512,
+            initializer_range: 0.02,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            model_type: Some("distilbert".to_string()),
+        }
+    }
+}
+
+struct Embeddings {
+    word_embeddings: Embedding,
+    position_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    span: tracing::Span,
+}
+
+impl Embeddings {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let word_embeddings =
+            candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
+        let position_embeddings = candle_nn::embedding(
+            config.max_position_embeddings,
+            config.dim,
+            vb.pp("position_embeddings"),
+        )?;
+        let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
+        Ok(Self {
+            word_embeddings,
+            position_embeddings,
+            layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+        })
+    }
+
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (_bsize, seq_len) = input_ids.dims2()?;
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
+        let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
+        let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
+        let embeddings =
+            input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
+
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
+struct MultiHeadSelfAttention {
+    q_lin: Linear,
+    k_lin: Linear,
+    v_lin: Linear,
+    out_lin: Linear,
+    n_heads: usize,
+    attention_head_size: usize,
+    span: tracing::Span,
+}
+
+impl MultiHeadSelfAttention {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let attention_head_size = config.dim / config.n_heads;
+        let all_head_size = config.n_heads * attention_head_size;
+        let dim = config.dim;
+        let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
+        let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
+        let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
+        let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
+        Ok(Self {
+            q_lin,
+            k_lin,
+            v_lin,
+            out_lin,
+            n_heads: config.n_heads,
+            attention_head_size,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
+        })
+    }
+}
+
+impl MultiHeadSelfAttention {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (bs, q_length, _dim) = hidden_states.dims3()?;
+
+        let dim_per_head = self.attention_head_size;
+        let q = self.q_lin.forward(hidden_states)?;
+        let k = self.k_lin.forward(hidden_states)?;
+        let v = self.v_lin.forward(hidden_states)?;
+
+        let q = q
+            .reshape((bs, q_length, self.n_heads, dim_per_head))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((bs, q_length, self.n_heads, dim_per_head))?
+            .transpose(1, 2)?;
+        let v = v
+            .reshape((bs, q_length, self.n_heads, dim_per_head))?
+            .transpose(1, 2)?;
+
+        let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
+        let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
+        let mask = attention_mask.broadcast_as(scores.shape())?;
+
+        let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
+        let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
+
+        let context = weights.matmul(&v.contiguous()?)?;
+        let context = context
+            .transpose(1, 2)?
+            .reshape((bs, q_length, self.n_heads * dim_per_head))?
+            .contiguous()?;
+        let context = self.out_lin.forward(&context)?;
+
+        Ok(context)
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+struct FFN {
+    lin1: Linear,
+    lin2: Linear,
+    activation: HiddenActLayer,
+    span: tracing::Span,
+}
+
+impl FFN {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
+        let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
+        Ok(Self {
+            lin1,
+            lin2,
+            activation: HiddenActLayer::new(config.activation),
+            span: tracing::span!(tracing::Level::TRACE, "ffn"),
+        })
+    }
+}
+
+impl Module for FFN {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        hidden_states
+            .apply(&self.lin1)?
+            .apply(&self.activation)?
+            .apply(&self.lin2)
+    }
+}
+
+struct TransformerBlock {
+    attention: MultiHeadSelfAttention,
+    sa_layer_norm: LayerNorm,
+    ffn: FFN,
+    output_layer_norm: LayerNorm,
+    span: tracing::Span,
+}
+
+impl TransformerBlock {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
+        let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
+        let ffn = FFN::load(vb.pp("ffn"), config)?;
+        let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
+        Ok(Self {
+            attention,
+            sa_layer_norm,
+            ffn,
+            output_layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
+        })
+    }
+}
+
+impl TransformerBlock {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let sa_output = self.attention.forward(hidden_states, attention_mask)?;
+        // TODO: Support cross-attention?
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+        // TODO: Support something similar to `apply_chunking_to_forward`?
+        let sa_output = sa_output.broadcast_add(hidden_states)?;
+        let sa_output = self.sa_layer_norm.forward(&sa_output)?;
+
+        let ffn_output = self.ffn.forward(&sa_output)?;
+        let ffn_output = (&ffn_output + sa_output)?;
+        let output = self.output_layer_norm.forward(&ffn_output)?;
+        Ok(output)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
+struct Transformer {
+    layers: Vec<TransformerBlock>,
+    span: tracing::Span,
+}
+
+impl Transformer {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let layers = (0..config.n_layers)
+            .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
+            .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "encoder");
+        Ok(Transformer { layers, span })
+    }
+}
+
+impl Transformer {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let mut hidden_states = hidden_states.clone();
+        // Use a loop rather than a fold as it's easier to modify when adding debug/...
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states, attention_mask)?;
+        }
+        Ok(hidden_states)
+    }
+}
+
+pub struct DistilBertModel {
+    embeddings: Embeddings,
+    transformer: Transformer,
+    pub device: Device,
+    span: tracing::Span,
+}
+
+impl DistilBertModel {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let (embeddings, transformer) = match (
+            Embeddings::load(vb.pp("embeddings"), config),
+            Transformer::load(vb.pp("transformer"), config),
+        ) {
+            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+            (Err(err), _) | (_, Err(err)) => {
+                if let Some(model_type) = &config.model_type {
+                    if let (Ok(embeddings), Ok(encoder)) = (
+                        Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
+                        Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
+                    ) {
+                        (embeddings, encoder)
+                    } else {
+                        return Err(err);
+                    }
+                } else {
+                    return Err(err);
+                }
+            }
+        };
+        Ok(Self {
+            embeddings,
+            transformer,
+            device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let embedding_output = self.embeddings.forward(input_ids)?;
+        let sequence_output = self
+            .transformer
+            .forward(&embedding_output, attention_mask)?;
+        Ok(sequence_output)
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 558583b6..a9a56673 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -4,6 +4,7 @@ pub mod blip;
 pub mod blip_text;
 pub mod convmixer;
 pub mod dinov2;
+pub mod distilbert;
 pub mod efficientnet;
 pub mod falcon;
 pub mod jina_bert;
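
Note on using the output: `model.forward` returns per-token hidden states of shape `[1, seq_len, 768]`, and the example above simply prints them. The `normalize_l2` helper and the `--normalize-embeddings` flag suggest reducing this to a single sentence vector; the sketch below shows one way to do that, assuming mean pooling over the token axis. The pooling choice and the `sentence_embedding` helper are illustrative assumptions, not part of this patch.

```rust
// Illustrative sketch only (not part of the patch): collapse the [1, seq_len, 768]
// hidden states returned by `model.forward` in main.rs into a single [1, 768]
// sentence embedding by mean pooling over the token axis, then reuse the
// normalize_l2 helper defined in main.rs.
fn sentence_embedding(ys: &candle::Tensor) -> anyhow::Result<candle::Tensor> {
    let (_batch, seq_len, _hidden) = ys.dims3()?;
    // Average over the token dimension: [1, seq_len, 768] -> [1, 768].
    let pooled = (ys.sum(1)? / (seq_len as f64))?;
    // Divide each row by its L2 norm so embeddings can be compared with a dot product.
    normalize_l2(&pooled)
}
```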