Added XLMRobertaModel for Reranking (#2686)

* add xlm-roberta-base

* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions

- Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
- Updated the logic in `main.rs` to handle tasks using the new enum.
- Enhanced README with example output for fill-mask task.
- Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety.

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Akshay Ballal
2024-12-30 11:16:57 +01:00
committed by GitHub
parent cd639131f0
commit 91f1f019b1
4 changed files with 853 additions and 0 deletions

View File

@ -0,0 +1,30 @@
# candle-xlm-roberta
This example demonstrates how to use the XLM-RoBERTa model in Candle especially known for their use in reranking. It uses the `fill-mask` task to generate a word for a masked token. And a `reranker` task to rerank a list of documents for a given query.
## Usage
Fill Mask:
```bash
cargo run --example xlm-roberta --release -- --task fill-mask --model xlm-roberta-base
```
```markdown
Sentence: 0 : Hello I'm a fashion model.
Sentence: 1 : I'm a little boy.
Sentence: 2 : I'm living in berlin.
```
Reranker:
```bash
cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-base
```
```markdown
Ranking Results:
--------------------------------------------------------------------------------
> Rank #4 | Score: 0.0001 | South Korea is a country in East Asia.
> Rank #5 | Score: 0.0000 | There are forests in the mountains.
> Rank #2 | Score: 0.7314 | Pandas look like bears.
> Rank #3 | Score: 0.6948 | There are some animals with black and white fur.
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
--------------------------------------------------------------------------------
```

View File

@ -0,0 +1,277 @@
use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{
Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Debug, Clone, ValueEnum)]
enum Model {
BgeRerankerBase,
BgeRerankerLarge,
BgeRerankerBaseV2,
XLMRobertaBase,
XLMRobertaLarge,
}
#[derive(Debug, Clone, ValueEnum)]
enum Task {
FillMask,
Reranker,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "bge-reranker-base")]
model: Model,
#[arg(long, default_value = "reranker")]
task: Task,
// Path to the tokenizer file.
#[arg(long)]
tokenizer_file: Option<String>,
// Path to the weight files.
#[arg(long)]
weight_files: Option<String>,
// Path to the config file.
#[arg(long)]
config_file: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.task {
Task::FillMask => match args.model {
Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
_ => anyhow::bail!("BGE models are not supported for fill-mask task"),
},
Task::Reranker => match args.model {
Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
_ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
},
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let weights_filename = match args.weight_files {
Some(files) => PathBuf::from(files),
None => match repo.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match repo.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
}
},
},
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let vb = if weights_filename.ends_with("model.safetensors") {
unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
.unwrap()
}
} else {
println!("Loading weights from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
};
tokenizer
.with_padding(Some(PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
pad_id: config.pad_token_id,
..Default::default()
}))
.with_truncation(None)
.map_err(E::msg)?;
match args.task {
Task::FillMask => {
let prompt = vec![
"Hello I'm a <mask> model.".to_string(),
"I'm a <mask> boy.".to_string(),
"I'm <mask> in berlin.".to_string(),
];
let model = XLMRobertaForMaskedLM::new(&config, vb)?;
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
let attention_mask =
get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
let output = model
.forward(
&input_ids,
&attention_mask,
&token_type_ids,
None,
None,
None,
)?
.to_dtype(candle::DType::F32)?;
let max_outs = output.argmax(2)?;
let max_out = max_outs.to_vec2::<u32>()?;
let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
for (i, sentence) in decoded.iter().enumerate() {
println!("Sentence: {} : {}", i + 1, sentence);
}
}
Task::Reranker => {
let query = "what is panda?".to_string();
let documents = ["South Korea is a country in East Asia.".to_string(),
"There are forests in the mountains.".to_string(),
"Pandas look like bears.".to_string(),
"There are some animals with black and white fur.".to_string(),
"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string()];
// create pairs of query and documents
let pairs = documents
.iter()
.map(|doc| (query.clone(), doc.clone()))
.collect::<Vec<_>>();
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
let attention_mask =
get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
let output = candle_nn::ops::sigmoid(&output)?.t().unwrap();
let ranks = output
.arg_sort_last_dim(false)?
.to_vec2::<u32>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
println!("\nRanking Results:");
println!("{:-<80}", "");
documents.iter().enumerate().for_each(|(idx, doc)| {
let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
let score = output
.get_on_dim(1, idx)
.unwrap()
.to_dtype(candle::DType::F32)
.unwrap()
.to_vec1::<f32>()
.unwrap();
println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
});
println!("{:-<80}", "");
}
}
Ok(())
}
#[derive(Debug)]
pub enum TokenizeInput<'a> {
Single(&'a [String]),
Pairs(&'a [(String, String)]),
}
pub fn tokenize_batch(
tokenizer: &Tokenizer,
input: TokenizeInput,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = match input {
TokenizeInput::Single(text_batch) => tokenizer
.encode_batch(text_batch.to_vec(), true)
.map_err(E::msg)?,
TokenizeInput::Pairs(pairs) => tokenizer
.encode_batch(pairs.to_vec(), true)
.map_err(E::msg)?,
};
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&token_ids, 0)?)
}
pub fn get_attention_mask(
tokenizer: &Tokenizer,
input: TokenizeInput,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = match input {
TokenizeInput::Single(text_batch) => tokenizer
.encode_batch(text_batch.to_vec(), true)
.map_err(E::msg)?,
TokenizeInput::Pairs(pairs) => tokenizer
.encode_batch(pairs.to_vec(), true)
.map_err(E::msg)?,
};
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&attention_mask, 0)?)
}

View File

@ -109,4 +109,5 @@ pub mod vit;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
pub mod xlm_roberta;
pub mod yi;

View File

@ -0,0 +1,545 @@
use crate::models::with_tracing::{linear, Linear};
use candle::{DType, Module, Result, Tensor};
use candle_nn::{
embedding, layer_norm, ops::softmax_last_dim, Activation, Embedding, LayerNorm, VarBuilder,
};
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub hidden_size: usize,
pub layer_norm_eps: f64,
pub attention_probs_dropout_prob: f32,
pub hidden_dropout_prob: f32,
pub num_attention_heads: usize,
pub position_embedding_type: String,
pub intermediate_size: usize,
pub hidden_act: Activation,
pub num_hidden_layers: usize,
pub vocab_size: usize,
pub max_position_embeddings: usize,
pub type_vocab_size: usize,
pub pad_token_id: u32,
}
struct XLMRobertaEmbeddings {
word_embeddings: Embedding,
position_embeddings: Option<Embedding>,
token_type_embeddings: Embedding,
layer_norm: LayerNorm,
padding_idx: u32,
span: tracing::Span,
}
impl XLMRobertaEmbeddings {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings = embedding(
config.vocab_size,
config.hidden_size,
vb.pp("word_embeddings"),
)?;
let position_embeddings = embedding(
config.max_position_embeddings,
config.hidden_size,
vb.pp("position_embeddings"),
)?;
let token_type_embeddings = embedding(
config.type_vocab_size,
config.hidden_size,
vb.pp("token_type_embeddings"),
)?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
vb.pp("LayerNorm"),
)?;
Ok(Self {
word_embeddings,
position_embeddings: Some(position_embeddings),
token_type_embeddings,
layer_norm,
padding_idx: config.pad_token_id,
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
})
}
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_bsize, _) = input_ids.dims2()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
if let Some(position_embeddings) = &self.position_embeddings {
let mask = input_ids
.ne(self.padding_idx)?
.to_dtype(input_embeddings.dtype())?;
let cumsum = mask.cumsum(1)?;
let position_ids = (cumsum * mask)?
.broadcast_add(
&Tensor::try_from(self.padding_idx)?
.to_dtype(input_embeddings.dtype())?
.to_device(input_embeddings.device())?,
)?
.to_dtype(candle::DType::U32)?;
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
}
let embeddings = self.layer_norm.forward(&embeddings)?;
Ok(embeddings)
}
}
struct XLMRobertaSelfAttention {
num_attention_heads: usize,
attention_head_size: usize,
all_head_size: usize,
query: Linear,
key: Linear,
value: Linear,
}
impl XLMRobertaSelfAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
let all_head_size = cfg.num_attention_heads * attention_head_size;
Ok(Self {
num_attention_heads: cfg.num_attention_heads,
attention_head_size,
all_head_size,
query: linear(cfg.hidden_size, all_head_size, vb.pp("query"))?,
key: linear(cfg.hidden_size, all_head_size, vb.pp("key"))?,
value: linear(cfg.hidden_size, all_head_size, vb.pp("value"))?,
})
}
fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
let mut new_x_shape = x.dims().to_vec();
new_x_shape[2] = self.num_attention_heads;
new_x_shape.push(self.attention_head_size);
let x = x.reshape(new_x_shape)?;
x.permute((0, 2, 1, 3))?.contiguous()
}
fn forward(
&self,
hidden_states: &Tensor,
encoder_hidden_states: Option<&Tensor>,
attention_mask: &Tensor,
past_key_value: Option<(&Tensor, &Tensor)>,
encoder_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let mixed_query_layer = self.query.forward(hidden_states)?;
let is_cross_attention = encoder_hidden_states.is_some();
let (key_layer, value_layer, attention_mask) = if is_cross_attention
&& past_key_value.is_some()
{
let key_layer = past_key_value.unwrap().0.clone();
let value_layer = past_key_value.unwrap().1.clone();
let attention_mask = encoder_attention_mask.unwrap().clone();
(key_layer, value_layer, Some(attention_mask))
} else if is_cross_attention {
let key_layer =
self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
let value_layer =
self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
let attention_mask = encoder_attention_mask.unwrap();
(key_layer, value_layer, Some(attention_mask.clone()))
} else if past_key_value.is_some() {
let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
key_layer = Tensor::cat(
&[
past_key_value.clone().as_ref().unwrap().0.clone(),
key_layer,
],
2,
)?;
value_layer = Tensor::cat(
&[past_key_value.as_ref().unwrap().1.clone(), value_layer],
2,
)?;
(key_layer, value_layer, Some(attention_mask.clone()))
} else {
let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
let value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
(key_layer, value_layer, Some(attention_mask.clone()))
};
let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
let mut attention_scores = query_layer.matmul(&key_layer.transpose(2, 3)?)?;
let scale = 1f64 / f64::sqrt(self.attention_head_size as f64);
attention_scores = (attention_scores * scale)?;
attention_scores = match attention_mask {
None => attention_scores,
Some(mask) => {
attention_scores.broadcast_add(&mask.to_dtype(attention_scores.dtype())?)?
}
};
let attention_probs = softmax_last_dim(&attention_scores)?;
let context_layer = attention_probs
.matmul(&value_layer)?
.permute((0, 2, 1, 3))?
.contiguous()?;
let mut new_context_layer_shape =
context_layer.dims()[..context_layer.dims().len() - 2].to_vec();
new_context_layer_shape.push(self.all_head_size);
let context_layer = context_layer.reshape(new_context_layer_shape)?;
Ok(context_layer)
}
}
struct XLMRobertaSelfOutput {
dense: Linear,
layernorm: LayerNorm,
}
impl XLMRobertaSelfOutput {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
let layernorm =
candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self { dense, layernorm })
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
Ok(hidden_states)
}
}
struct XLMRobertaAttention {
output: XLMRobertaSelfOutput,
self_attention: XLMRobertaSelfAttention,
}
impl XLMRobertaAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let output = XLMRobertaSelfOutput::new(cfg, vb.pp("output"))?;
let self_attention = XLMRobertaSelfAttention::new(cfg, vb.pp("self"))?;
Ok(Self {
output,
self_attention,
})
}
fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
past_key_value: Option<(&Tensor, &Tensor)>,
) -> Result<(Tensor, Tensor)> {
let self_outputs = self.self_attention.forward(
hidden_states,
encoder_hidden_states,
attention_mask,
past_key_value,
encoder_attention_mask,
)?;
let attention_output = self.output.forward(&self_outputs, hidden_states)?;
Ok((attention_output, self_outputs))
}
}
struct XLMRobertaOutput {
dense: Linear,
layernorm: LayerNorm,
}
impl XLMRobertaOutput {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
let layernorm =
candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self { dense, layernorm })
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
Ok(hidden_states)
}
}
struct XLMRobertaIntermediate {
dense: Linear,
intermediate_act_fn: Activation,
}
impl XLMRobertaIntermediate {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
let intermediate_act_fn = cfg.hidden_act;
Ok(Self {
dense,
intermediate_act_fn,
})
}
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.intermediate_act_fn.forward(&hidden_states)?;
Ok(hidden_states)
}
}
struct XLMRobertaLayer {
attention: XLMRobertaAttention,
intermediate: XLMRobertaIntermediate,
output: XLMRobertaOutput,
}
impl XLMRobertaLayer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention = XLMRobertaAttention::new(cfg, vb.pp("attention"))?;
let intermediate = XLMRobertaIntermediate::new(cfg, vb.pp("intermediate"))?;
let output = XLMRobertaOutput::new(cfg, vb.pp("output"))?;
Ok(Self {
attention,
intermediate,
output,
})
}
fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
past_key_value: Option<(&Tensor, &Tensor)>,
) -> Result<(Tensor, Tensor)> {
let self_attention_outputs = self.attention.forward(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
)?;
let attention_output = self_attention_outputs.0;
let outputs = self_attention_outputs.1;
let intermediate_output = self.intermediate.forward(&attention_output)?;
let layer_output = self
.output
.forward(&intermediate_output, &attention_output)?;
Ok((layer_output, outputs))
}
}
struct XLMRobertaEncoder {
layers: Vec<XLMRobertaLayer>,
}
impl XLMRobertaEncoder {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let layers = (0..cfg.num_hidden_layers)
.map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i))))
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
}
fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
past_key_value: Option<(&Tensor, &Tensor)>,
) -> Result<Tensor> {
let mut hidden_states = hidden_states.clone();
for layer_module in self.layers.iter() {
let layer_outputs = layer_module.forward(
&hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
)?;
hidden_states = layer_outputs.0;
}
Ok(hidden_states)
}
}
pub struct XLMRobertaModel {
encoder: XLMRobertaEncoder,
embeddings: XLMRobertaEmbeddings,
}
impl XLMRobertaModel {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let encoder = XLMRobertaEncoder::new(cfg, vb.pp("encoder"))?;
let embeddings = XLMRobertaEmbeddings::load(vb.pp("embeddings"), cfg)?;
Ok(Self {
encoder,
embeddings,
})
}
pub fn forward(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
token_type_ids: &Tensor,
past_key_value: Option<(&Tensor, &Tensor)>,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let hidden_states = self.embeddings.forward(input_ids, token_type_ids)?;
let attention_mask = prepare_4d_attention_mask(attention_mask, DType::F32, None)?
.to_device(hidden_states.device())?;
let hidden_states = self.encoder.forward(
&hidden_states,
&attention_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
)?;
Ok(hidden_states)
}
}
struct XLMRobertaLMHead {
dense: Linear,
layer_norm: LayerNorm,
}
impl XLMRobertaLMHead {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
let layer_norm =
candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layer_norm"))?;
Ok(Self { dense, layer_norm })
}
fn forward(&self, hidden_states: &Tensor, shared_embeddings: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = candle_nn::Activation::Gelu.forward(&hidden_states)?;
let hidden_states = self.layer_norm.forward(&hidden_states)?;
let hidden_states = hidden_states.broadcast_matmul(shared_embeddings)?;
Ok(hidden_states)
}
}
pub struct XLMRobertaForMaskedLM {
roberta: XLMRobertaModel,
lm_head: XLMRobertaLMHead,
}
impl XLMRobertaForMaskedLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
let lm_head = XLMRobertaLMHead::new(cfg, vb.pp("lm_head"))?;
Ok(Self { roberta, lm_head })
}
pub fn forward(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
token_type_ids: &Tensor,
past_key_value: Option<(&Tensor, &Tensor)>,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let hidden_states = self.roberta.forward(
input_ids,
attention_mask,
token_type_ids,
past_key_value,
encoder_hidden_states,
encoder_attention_mask,
)?;
let lm_logits = self.lm_head.forward(
&hidden_states,
&self
.roberta
.embeddings
.word_embeddings
.embeddings()
.t()?
.unsqueeze(0)?,
)?;
Ok(lm_logits)
}
}
struct XLMRobertaClassificationHead {
dense: Linear,
out_proj: Linear,
}
impl XLMRobertaClassificationHead {
fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
let out_proj = linear(cfg.hidden_size, num_labels, vb.pp("out_proj"))?;
Ok(Self { dense, out_proj })
}
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
let hidden_states = self.dense.forward(&cls_states)?;
let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
let hidden_states = self.out_proj.forward(&hidden_states)?;
Ok(hidden_states)
}
}
pub struct XLMRobertaForSequenceClassification {
roberta: XLMRobertaModel,
classifier: XLMRobertaClassificationHead,
}
impl XLMRobertaForSequenceClassification {
pub fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
let classifier = XLMRobertaClassificationHead::new(num_labels, cfg, vb.pp("classifier"))?;
Ok(Self {
roberta,
classifier,
})
}
pub fn forward(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
token_type_ids: &Tensor,
) -> Result<Tensor> {
let hidden_states =
self.roberta
.forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
self.classifier.forward(&hidden_states)
}
}
fn prepare_4d_attention_mask(
mask: &Tensor,
dtype: DType,
tgt_len: Option<usize>,
) -> Result<Tensor> {
let bsz = mask.dim(0)?;
let src_len = mask.dim(1)?;
let tgt_len = tgt_len.unwrap_or(src_len);
let expanded_mask = mask
.unsqueeze(1)?
.unsqueeze(2)?
.expand((bsz, 1, tgt_len, src_len))?
.to_dtype(dtype)?;
let inverted_mask = (1.0 - expanded_mask)?;
(inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
}
fn get_dtype_min_val(dtype: DType) -> f64 {
match dtype {
DType::F32 => f32::MIN as f64,
DType::F64 => f64::MIN,
_ => panic!("Unsupported data type"),
}
}