Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Add the jina-bert embeddings model. (#1187)
* Add the jina-bert model.
* Use alibi.
* Remove the unused pragma.
* Recompute the alibi embeddings.
* Generate the token type ids.
* Use the module trait.
* Add the jina-bert example.
* DType fix.
* Get the inference to work.
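Usage note: the example added in this commit drives the model end to end. As a minimal sketch of the new API (paths are placeholders, and the crate names assume the same candle / candle_nn / candle_transformers dependencies used by the example below), loading the weights and running a forward pass looks like this:

use candle::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::jina_bert::{BertModel, Config};

fn embed(token_ids: &[u32]) -> anyhow::Result<Tensor> {
    let device = Device::Cpu;
    // "model.safetensors" is a placeholder for the downloaded weight file.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = BertModel::new(vb, &Config::v2_base())?;
    // Shape (1, seq_len): a single, already tokenized prompt.
    let token_ids = Tensor::new(token_ids, &device)?.unsqueeze(0)?;
    Ok(model.forward(&token_ids)?) // (1, seq_len, hidden_size)
}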
candle-examples/examples/jina-bert/main.rs (new file, 162 lines)
@@ -0,0 +1,162 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_transformers::models::jina_bert::{BertModel, Config};

use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    #[arg(long)]
    tokenizer: String,

    #[arg(long)]
    model: String,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
        let device = candle_examples::device(self.cpu)?;
        let config = Config::v2_base();
        let tokenizer = tokenizers::Tokenizer::from_file(&self.tokenizer).map_err(E::msg)?;
        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[&self.model], DType::F32, &device)? };
        let model = BertModel::new(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let start = std::time::Instant::now();

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    if let Some(prompt) = args.prompt {
        let tokenizer = tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        let tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        println!("Loaded and encoded {:?}", start.elapsed());
        for idx in 0..args.n {
            let start = std::time::Instant::now();
            let ys = model.forward(&token_ids)?;
            if idx == 0 {
                println!("{ys}");
            }
            println!("Took {:?}", start.elapsed());
        }
    } else {
        let sentences = [
            "The cat sits outside",
            "A man is playing guitar",
            "I love pasta",
            "The new movie is awesome",
            "The cat plays in the garden",
            "A woman watches TV",
            "The new movie is so great",
            "Do you like pizza?",
        ];
        let n_sentences = sentences.len();
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = tokenizers::PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(sentences.to_vec(), true)
            .map_err(E::msg)?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<candle::Result<Vec<_>>>()?;

        let token_ids = Tensor::stack(&token_ids, 0)?;
        println!("running inference on batch {:?}", token_ids.shape());
        let embeddings = model.forward(&token_ids)?;
        println!("generated embeddings {:?}", embeddings.shape());
        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        let embeddings = if args.normalize_embeddings {
            normalize_l2(&embeddings)?
        } else {
            embeddings
        };
        println!("pooled embeddings {:?}", embeddings.shape());

        let mut similarities = vec![];
        for i in 0..n_sentences {
            let e_i = embeddings.get(i)?;
            for j in (i + 1)..n_sentences {
                let e_j = embeddings.get(j)?;
                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                similarities.push((cosine_similarity, i, j))
            }
        }
        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
        for &(score, i, j) in similarities[..5].iter() {
            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
        }
    }
    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
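Note on the similarity computation above: the embeddings are mean-pooled over the token axis and, with the default --normalize-embeddings, passed through normalize_l2, so the cosine similarity in the double loop reduces to a plain dot product. A small sketch of the same computation factored into a helper (assuming two already-pooled 1-D embedding tensors):

use candle::{Result, Tensor};

/// Cosine similarity between two 1-D tensors. For inputs already normalized
/// with normalize_l2 this is just their dot product.
fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f32> {
    let dot = (a * b)?.sum_all()?.to_scalar::<f32>()?;
    let norm_a = (a * a)?.sum_all()?.to_scalar::<f32>()?.sqrt();
    let norm_b = (b * b)?.sum_all()?.to_scalar::<f32>()?.sqrt();
    Ok(dot / (norm_a * norm_b))
}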
@@ -44,8 +44,10 @@ impl Linear {
         let span = tracing::span!(tracing::Level::TRACE, "linear");
         Self { weight, bias, span }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         let w = match x.dims() {
             &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
@@ -77,8 +79,10 @@ impl LayerNorm {
             span,
         }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl Module for LayerNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -195,7 +199,9 @@ impl Dropout {
     fn new(pr: f64) -> Self {
         Self { pr }
     }
+}
 
+impl Module for Dropout {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         // TODO
         Ok(x.clone())
@@ -316,7 +322,9 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
+}
 
+impl Module for BertSelfAttention {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
@@ -391,7 +399,9 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
+}
 
+impl Module for BertAttention {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let self_outputs = self.self_attention.forward(hidden_states)?;
@@ -416,7 +426,9 @@ impl BertIntermediate {
             span: tracing::span!(tracing::Level::TRACE, "inter"),
         })
     }
+}
 
+impl Module for BertIntermediate {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let hidden_states = self.dense.forward(hidden_states)?;
@@ -478,7 +490,9 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
+}
 
+impl Module for BertLayer {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let attention_output = self.attention.forward(hidden_states)?;
@@ -507,7 +521,9 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
+}
 
+impl Module for BertEncoder {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
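The hunks above do not change any computation: they close the inherent impl blocks and move the existing forward methods into impl Module blocks, so these layers expose candle's common Module interface. As a toy illustration (the Scale layer is invented for this note, not part of the diff), any Module can then be driven through Tensor::apply and composed generically:

use candle::{Result, Tensor};
use candle_nn::Module;

// Toy layer, for illustration only.
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Multiply every element by a constant.
        xs.affine(self.0, 0.0)
    }
}

// Generic over any Module, e.g. the BERT blocks refactored above.
fn run_twice<M: Module>(m: &M, xs: &Tensor) -> Result<Tensor> {
    xs.apply(m)?.apply(m)
}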
candle-transformers/src/models/jina_bert.rs (new file, 369 lines)
@@ -0,0 +1,369 @@
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
use serde::Deserialize;

pub const DTYPE: DType = DType::F32;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
    Absolute,
    Alibi,
}

// https://huggingface.co/jinaai/jina-bert-implementation/blob/main/configuration_bert.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub hidden_act: candle_nn::Activation,
    pub max_position_embeddings: usize,
    pub type_vocab_size: usize,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
    pub pad_token_id: usize,
    pub position_embedding_type: PositionEmbeddingType,
}

impl Config {
    pub fn v2_base() -> Self {
        // https://huggingface.co/jinaai/jina-embeddings-v2-base-en/blob/main/config.json
        Self {
            vocab_size: 30528,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: candle_nn::Activation::Gelu,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Alibi,
        }
    }
}

#[derive(Clone, Debug)]
struct BertEmbeddings {
    word_embeddings: Embedding,
    // no position_embeddings as we only support alibi.
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertEmbeddings {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let word_embeddings =
            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
        let token_type_embeddings = Embedding::new(
            cfg.type_vocab_size,
            cfg.hidden_size,
            vb.pp("token_type_embeddings"),
        )?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self {
            word_embeddings,
            token_type_embeddings,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }
}

impl Module for BertEmbeddings {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b_size, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let token_type_embeddings = Tensor::zeros(seq_len, DType::U32, input_ids.device())?
            .broadcast_left(b_size)?
            .apply(&self.token_type_embeddings)?;
        let embeddings = (&input_embeddings + token_type_embeddings)?;
        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}

#[derive(Clone, Debug)]
struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
    span_softmax: tracing::Span,
}

impl BertSelfAttention {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
        let all_head_size = cfg.num_attention_heads * attention_head_size;
        let hidden_size = cfg.hidden_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        Ok(Self {
            query,
            key,
            value,
            num_attention_heads: cfg.num_attention_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }

    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let mut x_shape = xs.dims().to_vec();
        x_shape.pop();
        x_shape.push(self.num_attention_heads);
        x_shape.push(self.attention_head_size);
        xs.reshape(x_shape)?.transpose(1, 2)?.contiguous()
    }

    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let query_layer = self.query.forward(xs)?;
        let key_layer = self.key.forward(xs)?;
        let value_layer = self.value.forward(xs)?;

        let query_layer = self.transpose_for_scores(&query_layer)?;
        let key_layer = self.transpose_for_scores(&key_layer)?;
        let value_layer = self.transpose_for_scores(&value_layer)?;

        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
        let attention_scores = attention_scores.broadcast_add(bias)?;
        let attention_probs = {
            let _enter_sm = self.span_softmax.enter();
            candle_nn::ops::softmax_last_dim(&attention_scores)?
        };
        let context_layer = attention_probs.matmul(&value_layer)?;
        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
        let context_layer = context_layer.flatten_from(D::Minus2)?;
        Ok(context_layer)
    }
}

#[derive(Clone, Debug)]
struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertSelfOutput {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "self-out"),
        })
    }

    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let xs = self.dense.forward(xs)?;
        self.layer_norm.forward(&(xs + input_tensor)?)
    }
}

#[derive(Clone, Debug)]
struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
    span: tracing::Span,
}

impl BertAttention {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let self_attention = BertSelfAttention::new(vb.pp("self"), cfg)?;
        let self_output = BertSelfOutput::new(vb.pp("output"), cfg)?;
        Ok(Self {
            self_attention,
            self_output,
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }

    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let self_outputs = self.self_attention.forward(xs, bias)?;
        let attention_output = self.self_output.forward(&self_outputs, xs)?;
        Ok(attention_output)
    }
}

#[derive(Clone, Debug)]
struct BertGLUMLP {
    gated_layers: Linear,
    act: candle_nn::Activation,
    wo: Linear,
    layernorm: LayerNorm,
    intermediate_size: usize,
}

impl BertGLUMLP {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let gated_layers = linear_no_bias(
            cfg.hidden_size,
            cfg.intermediate_size * 2,
            vb.pp("gated_layers"),
        )?;
        let act = candle_nn::Activation::Gelu; // geglu
        let wo = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("wo"))?;
        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
        Ok(Self {
            gated_layers,
            act,
            wo,
            layernorm,
            intermediate_size: cfg.intermediate_size,
        })
    }
}

impl Module for BertGLUMLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.gated_layers)?;
        let gated = xs.narrow(D::Minus1, 0, self.intermediate_size)?;
        let non_gated = xs.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
        let xs = (gated.apply(&self.act) * non_gated)?.apply(&self.wo);
        (xs + residual)?.apply(&self.layernorm)
    }
}

#[derive(Clone, Debug)]
struct BertLayer {
    attention: BertAttention,
    mlp: BertGLUMLP,
    span: tracing::Span,
}

impl BertLayer {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let attention = BertAttention::new(vb.pp("attention"), cfg)?;
        let mlp = BertGLUMLP::new(vb.pp("mlp"), cfg)?;
        Ok(Self {
            attention,
            mlp,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }

    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.attention.forward(xs, bias)?.apply(&self.mlp)
    }
}

fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
    let n_heads = cfg.num_attention_heads;
    let seq_len = cfg.max_position_embeddings;
    let alibi_bias = Tensor::arange(0, seq_len as i64, &Device::Cpu)?.to_dtype(DType::F32)?;
    let alibi_bias = {
        let a1 = alibi_bias.reshape((1, seq_len))?;
        let a2 = alibi_bias.reshape((seq_len, 1))?;
        a1.broadcast_sub(&a2)?.abs()?.broadcast_left(n_heads)?
    };
    let mut n_heads2 = 1;
    while n_heads2 < n_heads {
        n_heads2 *= 2
    }
    let slopes = (1..=n_heads2)
        .map(|v| 1f32 / 2f32.powf(8f32 / v as f32))
        .collect::<Vec<_>>();
    let slopes = if n_heads2 == n_heads {
        slopes
    } else {
        slopes
            .iter()
            .skip(1)
            .step_by(2)
            .chain(slopes.iter().step_by(2))
            .take(n_heads)
            .cloned()
            .collect::<Vec<f32>>()
    };
    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
}

#[derive(Clone, Debug)]
struct BertEncoder {
    alibi: Tensor,
    layers: Vec<BertLayer>,
    span: tracing::Span,
}

impl BertEncoder {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        if cfg.position_embedding_type != PositionEmbeddingType::Alibi {
            candle::bail!("only alibi is supported as a position-embedding-type")
        }
        let layers = (0..cfg.num_hidden_layers)
            .map(|index| BertLayer::new(vb.pp(&format!("layer.{index}")), cfg))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;
        Ok(Self {
            alibi,
            layers,
            span,
        })
    }
}

impl Module for BertEncoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let seq_len = xs.dim(1)?;
        let alibi_bias = self.alibi.i((.., .., ..seq_len, ..seq_len))?;
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, &alibi_bias)?
        }
        Ok(xs)
    }
}

#[derive(Clone, Debug)]
pub struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
    pub device: Device,
    span: tracing::Span,
}

impl BertModel {
    pub fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let embeddings = BertEmbeddings::new(vb.pp("embeddings"), cfg)?;
        let encoder = BertEncoder::new(vb.pp("encoder"), cfg)?;
        Ok(Self {
            embeddings,
            encoder,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }
}

impl Module for BertModel {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids)?;
        let sequence_output = self.encoder.forward(&embedding_output)?;
        Ok(sequence_output)
    }
}
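Since Config derives serde::Deserialize and mirrors the upstream configuration_bert.py fields, it can also be read from a model's config.json instead of using the hard-coded Config::v2_base(). A hypothetical loader sketch (serde_json and the path are assumptions, not part of this commit):

use candle_transformers::models::jina_bert::Config;

// "config.json" stands in for the file fetched from the Hugging Face hub.
fn load_config(path: &str) -> anyhow::Result<Config> {
    let json = std::fs::read_to_string(path)?;
    let config: Config = serde_json::from_str(&json)?;
    Ok(config)
}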
@@ -6,6 +6,7 @@ pub mod convmixer;
 pub mod dinov2;
 pub mod efficientnet;
 pub mod falcon;
+pub mod jina_bert;
 pub mod llama;
 pub mod mistral;
 pub mod mixformer;