Add the jina-bert embeddings model. (#1187)

* Add the jina-bert model.

* Use alibi.

* Remove the unused pragma.

* Recompute the alibi embeddings.

* Generate the token type ids.

* Use the module trait.

* Add the jina-bert example.

* DType fix.

* Get the inference to work.
This commit is contained in:
Laurent Mazare
2023-10-26 16:54:36 +01:00
committed by GitHub
parent e37b487767
commit 5f20697918
4 changed files with 550 additions and 2 deletions

View File

@ -0,0 +1,162 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::jina_bert::{BertModel, Config};
use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
#[arg(long)]
tokenizer: String,
#[arg(long)]
model: String,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(&self.tokenizer).map_err(E::msg)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[&self.model], DType::F32, &device)? };
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let start = std::time::Instant::now();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = embeddings.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

View File

@ -44,8 +44,10 @@ impl Linear {
let span = tracing::span!(tracing::Level::TRACE, "linear");
Self { weight, bias, span }
}
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
impl Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
@ -77,8 +79,10 @@ impl LayerNorm {
span,
}
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
@ -195,7 +199,9 @@ impl Dropout {
fn new(pr: f64) -> Self {
Self { pr }
}
}
impl Module for Dropout {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO
Ok(x.clone())
@ -316,7 +322,9 @@ impl BertSelfAttention {
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
xs.contiguous()
}
}
impl Module for BertSelfAttention {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let query_layer = self.query.forward(hidden_states)?;
@ -391,7 +399,9 @@ impl BertAttention {
span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
}
impl Module for BertAttention {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let self_outputs = self.self_attention.forward(hidden_states)?;
@ -416,7 +426,9 @@ impl BertIntermediate {
span: tracing::span!(tracing::Level::TRACE, "inter"),
})
}
}
impl Module for BertIntermediate {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
@ -478,7 +490,9 @@ impl BertLayer {
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
}
impl Module for BertLayer {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let attention_output = self.attention.forward(hidden_states)?;
@ -507,7 +521,9 @@ impl BertEncoder {
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(BertEncoder { layers, span })
}
}
impl Module for BertEncoder {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();

View File

@ -0,0 +1,369 @@
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
Absolute,
Alibi,
}
// https://huggingface.co/jinaai/jina-bert-implementation/blob/main/configuration_bert.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub intermediate_size: usize,
pub hidden_act: candle_nn::Activation,
pub max_position_embeddings: usize,
pub type_vocab_size: usize,
pub initializer_range: f64,
pub layer_norm_eps: f64,
pub pad_token_id: usize,
pub position_embedding_type: PositionEmbeddingType,
}
impl Config {
pub fn v2_base() -> Self {
// https://huggingface.co/jinaai/jina-embeddings-v2-base-en/blob/main/config.json
Self {
vocab_size: 30528,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: candle_nn::Activation::Gelu,
max_position_embeddings: 512,
type_vocab_size: 2,
initializer_range: 0.02,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Alibi,
}
}
}
#[derive(Clone, Debug)]
struct BertEmbeddings {
word_embeddings: Embedding,
// no position_embeddings as we only support alibi.
token_type_embeddings: Embedding,
layer_norm: LayerNorm,
span: tracing::Span,
}
impl BertEmbeddings {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let word_embeddings =
Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
let token_type_embeddings = Embedding::new(
cfg.type_vocab_size,
cfg.hidden_size,
vb.pp("token_type_embeddings"),
)?;
let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self {
word_embeddings,
token_type_embeddings,
layer_norm,
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
})
}
}
impl Module for BertEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_size, seq_len) = input_ids.dims2()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let token_type_embeddings = Tensor::zeros(seq_len, DType::U32, input_ids.device())?
.broadcast_left(b_size)?
.apply(&self.token_type_embeddings)?;
let embeddings = (&input_embeddings + token_type_embeddings)?;
let embeddings = self.layer_norm.forward(&embeddings)?;
Ok(embeddings)
}
}
#[derive(Clone, Debug)]
struct BertSelfAttention {
query: Linear,
key: Linear,
value: Linear,
num_attention_heads: usize,
attention_head_size: usize,
span: tracing::Span,
span_softmax: tracing::Span,
}
impl BertSelfAttention {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
let all_head_size = cfg.num_attention_heads * attention_head_size;
let hidden_size = cfg.hidden_size;
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
Ok(Self {
query,
key,
value,
num_attention_heads: cfg.num_attention_heads,
attention_head_size,
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
})
}
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
let mut x_shape = xs.dims().to_vec();
x_shape.pop();
x_shape.push(self.num_attention_heads);
x_shape.push(self.attention_head_size);
xs.reshape(x_shape)?.transpose(1, 2)?.contiguous()
}
fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let query_layer = self.query.forward(xs)?;
let key_layer = self.key.forward(xs)?;
let value_layer = self.value.forward(xs)?;
let query_layer = self.transpose_for_scores(&query_layer)?;
let key_layer = self.transpose_for_scores(&key_layer)?;
let value_layer = self.transpose_for_scores(&value_layer)?;
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
let attention_scores = attention_scores.broadcast_add(bias)?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
candle_nn::ops::softmax_last_dim(&attention_scores)?
};
let context_layer = attention_probs.matmul(&value_layer)?;
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
let context_layer = context_layer.flatten_from(D::Minus2)?;
Ok(context_layer)
}
}
#[derive(Clone, Debug)]
struct BertSelfOutput {
dense: Linear,
layer_norm: LayerNorm,
span: tracing::Span,
}
impl BertSelfOutput {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self {
dense,
layer_norm,
span: tracing::span!(tracing::Level::TRACE, "self-out"),
})
}
fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.dense.forward(xs)?;
self.layer_norm.forward(&(xs + input_tensor)?)
}
}
#[derive(Clone, Debug)]
struct BertAttention {
self_attention: BertSelfAttention,
self_output: BertSelfOutput,
span: tracing::Span,
}
impl BertAttention {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let self_attention = BertSelfAttention::new(vb.pp("self"), cfg)?;
let self_output = BertSelfOutput::new(vb.pp("output"), cfg)?;
Ok(Self {
self_attention,
self_output,
span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let self_outputs = self.self_attention.forward(xs, bias)?;
let attention_output = self.self_output.forward(&self_outputs, xs)?;
Ok(attention_output)
}
}
#[derive(Clone, Debug)]
struct BertGLUMLP {
gated_layers: Linear,
act: candle_nn::Activation,
wo: Linear,
layernorm: LayerNorm,
intermediate_size: usize,
}
impl BertGLUMLP {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let gated_layers = linear_no_bias(
cfg.hidden_size,
cfg.intermediate_size * 2,
vb.pp("gated_layers"),
)?;
let act = candle_nn::Activation::Gelu; // geglu
let wo = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("wo"))?;
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
Ok(Self {
gated_layers,
act,
wo,
layernorm,
intermediate_size: cfg.intermediate_size,
})
}
}
impl Module for BertGLUMLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.gated_layers)?;
let gated = xs.narrow(D::Minus1, 0, self.intermediate_size)?;
let non_gated = xs.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
let xs = (gated.apply(&self.act) * non_gated)?.apply(&self.wo);
(xs + residual)?.apply(&self.layernorm)
}
}
#[derive(Clone, Debug)]
struct BertLayer {
attention: BertAttention,
mlp: BertGLUMLP,
span: tracing::Span,
}
impl BertLayer {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attention = BertAttention::new(vb.pp("attention"), cfg)?;
let mlp = BertGLUMLP::new(vb.pp("mlp"), cfg)?;
Ok(Self {
attention,
mlp,
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.attention.forward(xs, bias)?.apply(&self.mlp)
}
}
fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
let n_heads = cfg.num_attention_heads;
let seq_len = cfg.max_position_embeddings;
let alibi_bias = Tensor::arange(0, seq_len as i64, &Device::Cpu)?.to_dtype(DType::F32)?;
let alibi_bias = {
let a1 = alibi_bias.reshape((1, seq_len))?;
let a2 = alibi_bias.reshape((seq_len, 1))?;
a1.broadcast_sub(&a2)?.abs()?.broadcast_left(n_heads)?
};
let mut n_heads2 = 1;
while n_heads2 < n_heads {
n_heads2 *= 2
}
let slopes = (1..=n_heads2)
.map(|v| 1f32 / 2f32.powf(8f32 / v as f32))
.collect::<Vec<_>>();
let slopes = if n_heads2 == n_heads {
slopes
} else {
slopes
.iter()
.skip(1)
.step_by(2)
.chain(slopes.iter().step_by(2))
.take(n_heads)
.cloned()
.collect::<Vec<f32>>()
};
let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
}
#[derive(Clone, Debug)]
struct BertEncoder {
alibi: Tensor,
layers: Vec<BertLayer>,
span: tracing::Span,
}
impl BertEncoder {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
if cfg.position_embedding_type != PositionEmbeddingType::Alibi {
candle::bail!("only alibi is supported as a position-embedding-type")
}
let layers = (0..cfg.num_hidden_layers)
.map(|index| BertLayer::new(vb.pp(&format!("layer.{index}")), cfg))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;
Ok(Self {
alibi,
layers,
span,
})
}
}
impl Module for BertEncoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let seq_len = xs.dim(1)?;
let alibi_bias = self.alibi.i((.., .., ..seq_len, ..seq_len))?;
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, &alibi_bias)?
}
Ok(xs)
}
}
#[derive(Clone, Debug)]
pub struct BertModel {
embeddings: BertEmbeddings,
encoder: BertEncoder,
pub device: Device,
span: tracing::Span,
}
impl BertModel {
pub fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let embeddings = BertEmbeddings::new(vb.pp("embeddings"), cfg)?;
let encoder = BertEncoder::new(vb.pp("encoder"), cfg)?;
Ok(Self {
embeddings,
encoder,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
}
impl Module for BertModel {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let embedding_output = self.embeddings.forward(input_ids)?;
let sequence_output = self.encoder.forward(&embedding_output)?;
Ok(sequence_output)
}
}

View File

@ -6,6 +6,7 @@ pub mod convmixer;
pub mod dinov2;
pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;
pub mod llama;
pub mod mistral;
pub mod mixformer;