diff --git a/candle-examples/examples/jina-bert/main.rs b/candle-examples/examples/jina-bert/main.rs
new file mode 100644
index 00000000..ffde777d
--- /dev/null
+++ b/candle-examples/examples/jina-bert/main.rs
@@ -0,0 +1,162 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle_transformers::models::jina_bert::{BertModel, Config};
+
+use anyhow::Error as E;
+use candle::{DType, Module, Tensor};
+use candle_nn::VarBuilder;
+use clap::Parser;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+
+    /// L2 normalization for embeddings.
+    #[arg(long, default_value = "true")]
+    normalize_embeddings: bool,
+
+    #[arg(long)]
+    tokenizer: String,
+
+    #[arg(long)]
+    model: String,
+}
+
+impl Args {
+    fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let config = Config::v2_base();
+        let tokenizer = tokenizers::Tokenizer::from_file(&self.tokenizer).map_err(E::msg)?;
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[&self.model], DType::F32, &device)? };
+        let model = BertModel::new(vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
+fn main() -> anyhow::Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    let start = std::time::Instant::now();
+
+    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let device = &model.device;
+
+    if let Some(prompt) = args.prompt {
+        let tokenizer = tokenizer
+            .with_padding(None)
+            .with_truncation(None)
+            .map_err(E::msg)?;
+        let tokens = tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        println!("Loaded and encoded {:?}", start.elapsed());
+        for idx in 0..args.n {
+            let start = std::time::Instant::now();
+            let ys = model.forward(&token_ids)?;
+            if idx == 0 {
+                println!("{ys}");
+            }
+            println!("Took {:?}", start.elapsed());
+        }
+    } else {
+        let sentences = [
+            "The cat sits outside",
+            "A man is playing guitar",
+            "I love pasta",
+            "The new movie is awesome",
+            "The cat plays in the garden",
+            "A woman watches TV",
+            "The new movie is so great",
+            "Do you like pizza?",
+        ];
+        let n_sentences = sentences.len();
+        if let Some(pp) = tokenizer.get_padding_mut() {
+            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        } else {
+            let pp = tokenizers::PaddingParams {
+                strategy: tokenizers::PaddingStrategy::BatchLongest,
+                ..Default::default()
+            };
+            tokenizer.with_padding(Some(pp));
+        }
+        let tokens = tokenizer
+            .encode_batch(sentences.to_vec(), true)
+            .map_err(E::msg)?;
+        let token_ids = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_ids().to_vec();
+                Tensor::new(tokens.as_slice(), device)
+            })
+            .collect::<candle::Result<Vec<_>>>()?;
+
+        let token_ids = Tensor::stack(&token_ids, 0)?;
+        println!("running inference on batch {:?}", token_ids.shape());
+        let embeddings = model.forward(&token_ids)?;
+        println!("generated embeddings {:?}", embeddings.shape());
+        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+        let embeddings = if args.normalize_embeddings {
+            normalize_l2(&embeddings)?
+        } else {
+            embeddings
+        };
+        println!("pooled embeddings {:?}", embeddings.shape());
+
+        let mut similarities = vec![];
+        for i in 0..n_sentences {
+            let e_i = embeddings.get(i)?;
+            for j in (i + 1)..n_sentences {
+                let e_j = embeddings.get(j)?;
+                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+                similarities.push((cosine_similarity, i, j))
+            }
+        }
+        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+        for &(score, i, j) in similarities[..5].iter() {
+            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+        }
+    }
+    Ok(())
+}
+
+pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
+    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
+}
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 8af34465..29e2904e 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -44,8 +44,10 @@ impl Linear {
         let span = tracing::span!(tracing::Level::TRACE, "linear");
         Self { weight, bias, span }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         let w = match x.dims() {
             &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
@@ -77,8 +79,10 @@ impl LayerNorm {
             span,
         }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl Module for LayerNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -195,7 +199,9 @@ impl Dropout {
     fn new(pr: f64) -> Self {
         Self { pr }
     }
+}
 
+impl Module for Dropout {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         // TODO
         Ok(x.clone())
@@ -316,7 +322,9 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
+}
 
+impl Module for BertSelfAttention {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
@@ -391,7 +399,9 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
+}
 
+impl Module for BertAttention {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let self_outputs = self.self_attention.forward(hidden_states)?;
@@ -416,7 +426,9 @@ impl BertIntermediate {
             span: tracing::span!(tracing::Level::TRACE, "inter"),
         })
     }
+}
 
+impl Module for BertIntermediate {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let hidden_states = self.dense.forward(hidden_states)?;
@@ -478,7 +490,9 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
+}
 
+impl Module for BertLayer {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let attention_output = self.attention.forward(hidden_states)?;
@@ -507,7 +521,9 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
+}
 
+impl Module for BertEncoder {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
diff --git a/candle-transformers/src/models/jina_bert.rs b/candle-transformers/src/models/jina_bert.rs
new file mode 100644
index 00000000..3f08eaea
--- /dev/null
+++ b/candle-transformers/src/models/jina_bert.rs
@@ -0,0 +1,369 @@
+use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
+use serde::Deserialize;
+
+pub const DTYPE: DType = DType::F32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum PositionEmbeddingType {
+    Absolute,
+    Alibi,
+}
+
+// https://huggingface.co/jinaai/jina-bert-implementation/blob/main/configuration_bert.py
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub max_position_embeddings: usize,
+    pub type_vocab_size: usize,
+    pub initializer_range: f64,
+    pub layer_norm_eps: f64,
+    pub pad_token_id: usize,
+    pub position_embedding_type: PositionEmbeddingType,
+}
+
+impl Config {
+    pub fn v2_base() -> Self {
+        // https://huggingface.co/jinaai/jina-embeddings-v2-base-en/blob/main/config.json
+        Self {
+            vocab_size: 30528,
+            hidden_size: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            intermediate_size: 3072,
+            hidden_act: candle_nn::Activation::Gelu,
+            max_position_embeddings: 512,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Alibi,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+struct BertEmbeddings {
+    word_embeddings: Embedding,
+    // no position_embeddings as we only support alibi.
+    token_type_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    span: tracing::Span,
+}
+
+impl BertEmbeddings {
+    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let word_embeddings =
+            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
+        let token_type_embeddings = Embedding::new(
+            cfg.type_vocab_size,
+            cfg.hidden_size,
+            vb.pp("token_type_embeddings"),
+        )?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self {
+            word_embeddings,
+            token_type_embeddings,
+            layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+        })
+    }
+}
+
+impl Module for BertEmbeddings {
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (b_size, seq_len) = input_ids.dims2()?;
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
+        let token_type_embeddings = Tensor::zeros(seq_len, DType::U32, input_ids.device())?
+            .broadcast_left(b_size)?
+            .apply(&self.token_type_embeddings)?;
+        let embeddings = (&input_embeddings + token_type_embeddings)?;
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct BertSelfAttention {
+    query: Linear,
+    key: Linear,
+    value: Linear,
+    num_attention_heads: usize,
+    attention_head_size: usize,
+    span: tracing::Span,
+    span_softmax: tracing::Span,
+}
+
+impl BertSelfAttention {
+    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
+        let all_head_size = cfg.num_attention_heads * attention_head_size;
+        let hidden_size = cfg.hidden_size;
+        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
+        Ok(Self {
+            query,
+            key,
+            value,
+            num_attention_heads: cfg.num_attention_heads,
+            attention_head_size,
+            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
+        })
+    }
+
+    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut x_shape = xs.dims().to_vec();
+        x_shape.pop();
+        x_shape.push(self.num_attention_heads);
+        x_shape.push(self.attention_head_size);
+        xs.reshape(x_shape)?.transpose(1, 2)?.contiguous()
+    }
+
+    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let query_layer = self.query.forward(xs)?;
+        let key_layer = self.key.forward(xs)?;
+        let value_layer = self.value.forward(xs)?;
+
+        let query_layer = self.transpose_for_scores(&query_layer)?;
+        let key_layer = self.transpose_for_scores(&key_layer)?;
+        let value_layer = self.transpose_for_scores(&value_layer)?;
+
+        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
+        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_scores = attention_scores.broadcast_add(bias)?;
+        let attention_probs = {
+            let _enter_sm = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attention_scores)?
+        };
+        let context_layer = attention_probs.matmul(&value_layer)?;
+        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+        let context_layer = context_layer.flatten_from(D::Minus2)?;
+        Ok(context_layer)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct BertSelfOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+    span: tracing::Span,
+}
+
+impl BertSelfOutput {
+    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self {
+            dense,
+            layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "self-out"),
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let xs = self.dense.forward(xs)?;
+        self.layer_norm.forward(&(xs + input_tensor)?)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct BertAttention {
+    self_attention: BertSelfAttention,
+    self_output: BertSelfOutput,
+    span: tracing::Span,
+}
+
+impl BertAttention {
+    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let self_attention = BertSelfAttention::new(vb.pp("self"), cfg)?;
+        let self_output = BertSelfOutput::new(vb.pp("output"), cfg)?;
+        Ok(Self {
+            self_attention,
+            self_output,
+            span: tracing::span!(tracing::Level::TRACE, "attn"),
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let self_outputs = self.self_attention.forward(xs, bias)?;
+        let attention_output = self.self_output.forward(&self_outputs, xs)?;
+        Ok(attention_output)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct BertGLUMLP {
+    gated_layers: Linear,
+    act: candle_nn::Activation,
+    wo: Linear,
+    layernorm: LayerNorm,
+    intermediate_size: usize,
+}
+
+impl BertGLUMLP {
+    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let gated_layers = linear_no_bias(
+            cfg.hidden_size,
+            cfg.intermediate_size * 2,
+            vb.pp("gated_layers"),
+        )?;
+        let act = candle_nn::Activation::Gelu; // geglu
+        let wo = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("wo"))?;
+        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
+        Ok(Self {
+            gated_layers,
+            act,
+            wo,
+            layernorm,
+            intermediate_size: cfg.intermediate_size,
+        })
+    }
+}
+
+impl Module for BertGLUMLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let residual = xs;
+        let xs = xs.apply(&self.gated_layers)?;
+        let gated = xs.narrow(D::Minus1, 0, self.intermediate_size)?;
+        let non_gated = xs.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
+        let xs = (gated.apply(&self.act) * non_gated)?.apply(&self.wo);
+        (xs + residual)?.apply(&self.layernorm)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct BertLayer {
+    attention: BertAttention,
+    mlp: BertGLUMLP,
+    span: tracing::Span,
+}
+
+impl BertLayer {
+    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let attention = BertAttention::new(vb.pp("attention"), cfg)?;
+        let mlp = BertGLUMLP::new(vb.pp("mlp"), cfg)?;
+        Ok(Self {
+            attention,
+            mlp,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.attention.forward(xs, bias)?.apply(&self.mlp)
+    }
+}
+
+fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
+    let n_heads = cfg.num_attention_heads;
+    let seq_len = cfg.max_position_embeddings;
+    let alibi_bias = Tensor::arange(0, seq_len as i64, &Device::Cpu)?.to_dtype(DType::F32)?;
+    let alibi_bias = {
+        let a1 = alibi_bias.reshape((1, seq_len))?;
+        let a2 = alibi_bias.reshape((seq_len, 1))?;
+        a1.broadcast_sub(&a2)?.abs()?.broadcast_left(n_heads)?
+    };
+    let mut n_heads2 = 1;
+    while n_heads2 < n_heads {
+        n_heads2 *= 2
+    }
+    let slopes = (1..=n_heads2)
+        .map(|v| 1f32 / 2f32.powf(8f32 / v as f32))
+        .collect::<Vec<_>>();
+    let slopes = if n_heads2 == n_heads {
+        slopes
+    } else {
+        slopes
+            .iter()
+            .skip(1)
+            .step_by(2)
+            .chain(slopes.iter().step_by(2))
+            .take(n_heads)
+            .cloned()
+            .collect::<Vec<f32>>()
+    };
+    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
+    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
+}
+
+#[derive(Clone, Debug)]
+struct BertEncoder {
+    alibi: Tensor,
+    layers: Vec<BertLayer>,
+    span: tracing::Span,
+}
+
+impl BertEncoder {
+    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        if cfg.position_embedding_type != PositionEmbeddingType::Alibi {
+            candle::bail!("only alibi is supported as a position-embedding-type")
+        }
+        let layers = (0..cfg.num_hidden_layers)
+            .map(|index| BertLayer::new(vb.pp(&format!("layer.{index}")), cfg))
+            .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "encoder");
+        let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;
+        Ok(Self {
+            alibi,
+            layers,
+            span,
+        })
+    }
+}
+
+impl Module for BertEncoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let seq_len = xs.dim(1)?;
+        let alibi_bias = self.alibi.i((.., .., ..seq_len, ..seq_len))?;
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs, &alibi_bias)?
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct BertModel {
+    embeddings: BertEmbeddings,
+    encoder: BertEncoder,
+    pub device: Device,
+    span: tracing::Span,
+}
+
+impl BertModel {
+    pub fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let embeddings = BertEmbeddings::new(vb.pp("embeddings"), cfg)?;
+        let encoder = BertEncoder::new(vb.pp("encoder"), cfg)?;
+        Ok(Self {
+            embeddings,
+            encoder,
+            device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+        })
+    }
+}
+
+impl Module for BertModel {
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let embedding_output = self.embeddings.forward(input_ids)?;
+        let sequence_output = self.encoder.forward(&embedding_output)?;
+        Ok(sequence_output)
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index ce576c54..4e7c8bf0 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -6,6 +6,7 @@ pub mod convmixer;
 pub mod dinov2;
 pub mod efficientnet;
 pub mod falcon;
+pub mod jina_bert;
 pub mod llama;
 pub mod mistral;
 pub mod mixformer;