mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add BertForMaskedLM to support SPLADE Models (#2550)
* add bert for masked lm * working example * add example readme * Clippy fix. * And apply rustfmt. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
28
candle-examples/examples/splade/README.md
Normal file
28
candle-examples/examples/splade/README.md
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# candle-splade
|
||||||
|
|
||||||
|
SPLADE is a neural retrieval model which learns query/document sparse expansion via the BERT MLM head and sparse regularization. Sparse representations benefit from several advantages compared to dense approaches: efficient use of inverted index, explicit lexical match, interpretability... They also seem to be better at generalizing on out-of-domain data. In this example we can do the following two tasks:
|
||||||
|
|
||||||
|
- Compute sparse embedding for a given query.
|
||||||
|
- Compute similarities between a set of sentences using sparse embeddings.
|
||||||
|
|
||||||
|
## Sparse Sentence embeddings
|
||||||
|
|
||||||
|
SPLADE is used to compute the sparse embedding for a given query. The model weights
|
||||||
|
are downloaded from the hub on the first run. This makes use of the BertForMaskedLM model.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example splade --release -- --prompt "Here is a test sentence"
|
||||||
|
|
||||||
|
> "the out there still house inside position outside stay standing hotel sitting dog animal sit bird cat statue cats"
|
||||||
|
> [0.10270107, 0.269471, 0.047469813, 0.0016636598, 0.05394874, 0.23105666, 0.037475716, 0.45949644, 0.009062732, 0.06790692, 0.0327835, 0.33122346, 0.16863061, 0.12688516, 0.340983, 0.044972017, 0.47724655, 0.01765311, 0.37331146]
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example splade --release
|
||||||
|
|
||||||
|
> score: 0.47 'The new movie is awesome' 'The new movie is so great'
|
||||||
|
> score: 0.43 'The cat sits outside' 'The cat plays in the garden'
|
||||||
|
> score: 0.14 'I love pasta' 'Do you like pizza?'
|
||||||
|
> score: 0.11 'A man is playing guitar' 'The cat plays in the garden'
|
||||||
|
> score: 0.05 'A man is playing guitar' 'A woman watches TV'
|
||||||
|
```
|
210
candle-examples/examples/splade/main.rs
Normal file
210
candle-examples/examples/splade/main.rs
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use candle::Tensor;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::bert::{self, BertForMaskedLM, Config};
|
||||||
|
use clap::Parser;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::{PaddingParams, Tokenizer};
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
// Path to the tokenizer file.
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
// Path to the weight files.
|
||||||
|
#[arg(long)]
|
||||||
|
weight_files: Option<String>,
|
||||||
|
|
||||||
|
// Path to the config file.
|
||||||
|
#[arg(long)]
|
||||||
|
config_file: Option<String>,
|
||||||
|
|
||||||
|
/// When set, compute embeddings for this prompt.
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match &args.model_id {
|
||||||
|
Some(model_id) => model_id.to_string(),
|
||||||
|
None => "prithivida/Splade_PP_en_v1".to_string(),
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
|
||||||
|
let tokenizer_filename = match args.tokenizer_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let config_filename = match args.config_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("config.json")?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let weights_filename = match args.weight_files {
|
||||||
|
Some(files) => PathBuf::from(files),
|
||||||
|
None => match repo.get("model.safetensors") {
|
||||||
|
Ok(safetensors) => safetensors,
|
||||||
|
Err(_) => match repo.get("pytorch_model.bin") {
|
||||||
|
Ok(pytorch_model) => pytorch_model,
|
||||||
|
Err(e) => {
|
||||||
|
return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
|
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = bert::DTYPE;
|
||||||
|
|
||||||
|
let vb = if weights_filename.ends_with("model.safetensors") {
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device).unwrap() }
|
||||||
|
} else {
|
||||||
|
println!("Loading weights from pytorch_model.bin");
|
||||||
|
VarBuilder::from_pth(&weights_filename, dtype, &device).unwrap()
|
||||||
|
};
|
||||||
|
let model = BertForMaskedLM::load(vb, &config)?;
|
||||||
|
|
||||||
|
if let Some(prompt) = args.prompt {
|
||||||
|
let tokenizer = tokenizer
|
||||||
|
.with_padding(None)
|
||||||
|
.with_truncation(None)
|
||||||
|
.map_err(E::msg)?;
|
||||||
|
let tokens = tokenizer
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
||||||
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
|
|
||||||
|
let ys = model.forward(&token_ids, &token_type_ids, None)?;
|
||||||
|
let vec = Tensor::log(
|
||||||
|
&Tensor::try_from(1.0)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.to_device(&device)?
|
||||||
|
.broadcast_add(&ys.relu()?)?,
|
||||||
|
)?
|
||||||
|
.max(1)?;
|
||||||
|
let vec = normalize_l2(&vec)?;
|
||||||
|
|
||||||
|
let vec = vec.squeeze(0)?.to_vec1::<f32>()?;
|
||||||
|
|
||||||
|
let indices = (0..vec.len())
|
||||||
|
.filter(|&i| vec[i] != 0.0)
|
||||||
|
.map(|x| x as u32)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let tokens = tokenizer.decode(&indices, true).unwrap();
|
||||||
|
println!("{tokens:?}");
|
||||||
|
let values = indices.iter().map(|&i| vec[i as usize]).collect::<Vec<_>>();
|
||||||
|
println!("{values:?}");
|
||||||
|
} else {
|
||||||
|
let sentences = [
|
||||||
|
"The cat sits outside",
|
||||||
|
"A man is playing guitar",
|
||||||
|
"I love pasta",
|
||||||
|
"The new movie is awesome",
|
||||||
|
"The cat plays in the garden",
|
||||||
|
"A woman watches TV",
|
||||||
|
"The new movie is so great",
|
||||||
|
"Do you like pizza?",
|
||||||
|
];
|
||||||
|
|
||||||
|
let n_sentences = sentences.len();
|
||||||
|
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||||
|
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||||
|
} else {
|
||||||
|
let pp = PaddingParams {
|
||||||
|
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
tokenizer.with_padding(Some(pp));
|
||||||
|
}
|
||||||
|
let tokens = tokenizer
|
||||||
|
.encode_batch(sentences.to_vec(), true)
|
||||||
|
.map_err(E::msg)?;
|
||||||
|
let token_ids = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|tokens| {
|
||||||
|
let tokens = tokens.get_ids().to_vec();
|
||||||
|
Ok(Tensor::new(tokens.as_slice(), &device)?)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let attention_mask = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|tokens| {
|
||||||
|
let tokens = tokens.get_attention_mask().to_vec();
|
||||||
|
Ok(Tensor::new(tokens.as_slice(), &device)?)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
let token_ids = Tensor::stack(&token_ids, 0)?;
|
||||||
|
let attention_mask = Tensor::stack(&attention_mask, 0)?;
|
||||||
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
|
|
||||||
|
let ys = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
|
||||||
|
let vector = Tensor::log(
|
||||||
|
&Tensor::try_from(1.0)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.to_device(&device)?
|
||||||
|
.broadcast_add(&ys.relu()?)?,
|
||||||
|
)?;
|
||||||
|
let vector = vector
|
||||||
|
.broadcast_mul(&attention_mask.unsqueeze(2)?.to_dtype(dtype)?)?
|
||||||
|
.max(1)?;
|
||||||
|
let vec = normalize_l2(&vector)?;
|
||||||
|
let mut similarities = vec![];
|
||||||
|
for i in 0..n_sentences {
|
||||||
|
let e_i = vec.get(i)?;
|
||||||
|
for j in (i + 1)..n_sentences {
|
||||||
|
let e_j = vec.get(j)?;
|
||||||
|
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
|
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
|
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
|
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
|
||||||
|
similarities.push((cosine_similarity, i, j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
|
||||||
|
for &(score, i, j) in similarities[..5].iter() {
|
||||||
|
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Divides `v` by the square root of the sum of its squares taken along
/// dimension 1 (kept as a size-1 dimension so the division broadcasts).
///
/// # Errors
/// Propagates any error returned by the underlying tensor operations.
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    let squared_sum = v.sqr()?.sum_keepdim(1)?;
    let norm = squared_sum.sqrt()?;
    let normalized = v.broadcast_div(&norm)?;
    Ok(normalized)
}
|
@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
|||||||
(attention_mask.ones_like()? - &attention_mask)?
|
(attention_mask.ones_like()? - &attention_mask)?
|
||||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
|
||||||
|
struct BertPredictionHeadTransform {
|
||||||
|
dense: Linear,
|
||||||
|
activation: HiddenActLayer,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertPredictionHeadTransform {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
|
||||||
|
let activation = HiddenActLayer::new(config.hidden_act);
|
||||||
|
let layer_norm = layer_norm(
|
||||||
|
config.hidden_size,
|
||||||
|
config.layer_norm_eps,
|
||||||
|
vb.pp("LayerNorm"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
dense,
|
||||||
|
activation,
|
||||||
|
layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for BertPredictionHeadTransform {
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let hidden_states = self
|
||||||
|
.activation
|
||||||
|
.forward(&self.dense.forward(hidden_states)?)?;
|
||||||
|
self.layer_norm.forward(&hidden_states)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
|
||||||
|
pub struct BertLMPredictionHead {
|
||||||
|
transform: BertPredictionHeadTransform,
|
||||||
|
decoder: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertLMPredictionHead {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?;
|
||||||
|
let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?;
|
||||||
|
Ok(Self { transform, decoder })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for BertLMPredictionHead {
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
self.decoder
|
||||||
|
.forward(&self.transform.forward(hidden_states)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
|
||||||
|
pub struct BertOnlyMLMHead {
|
||||||
|
predictions: BertLMPredictionHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertOnlyMLMHead {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?;
|
||||||
|
Ok(Self { predictions })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for BertOnlyMLMHead {
|
||||||
|
fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
|
||||||
|
self.predictions.forward(sequence_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct BertForMaskedLM {
|
||||||
|
bert: BertModel,
|
||||||
|
cls: BertOnlyMLMHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BertForMaskedLM {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let bert = BertModel::load(vb.pp("bert"), config)?;
|
||||||
|
let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
|
||||||
|
Ok(Self { bert, cls })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&self,
|
||||||
|
input_ids: &Tensor,
|
||||||
|
token_type_ids: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let sequence_output = self
|
||||||
|
.bert
|
||||||
|
.forward(input_ids, token_type_ids, attention_mask)?;
|
||||||
|
self.cls.forward(&sequence_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user