Added XLMRobertaModel for Reranking (#2686)
* add xlm-roberta-base

* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions

  - Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
  - Updated the logic in `main.rs` to handle tasks using the new enum.
  - Enhanced README with example output for fill-mask task.
  - Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety.

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
candle-examples/examples/xlm-roberta/Readme.md (new file, 30 lines)
@@ -0,0 +1,30 @@
# candle-xlm-roberta

This example demonstrates how to use the XLM-RoBERTa model in Candle. XLM-RoBERTa models are best known for their use in reranking. The example supports a `fill-mask` task, which generates a word for a masked token, and a `reranker` task, which reranks a list of documents for a given query.

## Usage

Fill Mask:

```bash
cargo run --example xlm-roberta --release -- --task fill-mask --model xlm-roberta-base
```

```markdown
Sentence: 1 : Hello I'm a fashion model.
Sentence: 2 : I'm a little boy.
Sentence: 3 : I'm living in berlin.
```

Reranker:

```bash
cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-base
```

```markdown
Ranking Results:
--------------------------------------------------------------------------------
> Rank #4 | Score: 0.0001 | South Korea is a country in East Asia.
> Rank #5 | Score: 0.0000 | There are forests in the mountains.
> Rank #2 | Score: 0.7314 | Pandas look like bears.
> Rank #3 | Score: 0.6948 | There are some animals with black and white fur.
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
--------------------------------------------------------------------------------
```
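
The commands above drive the example through its CLI. For library-style use, the reranker path condenses to roughly the following single-pair sketch. This is not part of this change, just an illustration: it assumes the checkpoint provides a `model.safetensors` file (otherwise load with `VarBuilder::from_pth` as `main.rs` does below), runs on CPU for simplicity, and skips the batching and weight-file fallbacks handled in `main.rs`.

```rust
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{Config, XLMRobertaForSequenceClassification};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

fn main() -> Result<()> {
    let device = Device::Cpu;
    let repo = Api::new()?.repo(Repo::with_revision(
        "BAAI/bge-reranker-base".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    let config = std::fs::read_to_string(repo.get("config.json")?)?;
    let config: Config = serde_json::from_str(&config)?;
    let tokenizer = Tokenizer::from_file(repo.get("tokenizer.json")?).map_err(E::msg)?;
    // Memory-map the safetensors weights and load them as f16.
    let weights = repo.get("model.safetensors")?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], candle::DType::F16, &device)? };
    // One output label: the model produces a single relevance logit per (query, document) pair.
    let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;

    // Encode a single (query, document) pair; no padding is needed for a batch of one.
    let encoding = tokenizer
        .encode(("what is panda?", "Pandas look like bears."), true)
        .map_err(E::msg)?;
    let input_ids = Tensor::new(encoding.get_ids(), &device)?.unsqueeze(0)?;
    let attention_mask = Tensor::new(encoding.get_attention_mask(), &device)?.unsqueeze(0)?;
    let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
    let logits = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
    // Sigmoid maps the logit to a relevance score in [0, 1].
    let score = candle_nn::ops::sigmoid(&logits)?
        .to_dtype(candle::DType::F32)?
        .to_vec2::<f32>()?[0][0];
    println!("relevance score: {score:.4}");
    Ok(())
}
```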

candle-examples/examples/xlm-roberta/main.rs (new file, 277 lines)
@@ -0,0 +1,277 @@
use std::path::PathBuf;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{
    Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};

#[derive(Debug, Clone, ValueEnum)]
enum Model {
    BgeRerankerBase,
    BgeRerankerLarge,
    BgeRerankerBaseV2,
    XLMRobertaBase,
    XLMRobertaLarge,
}

#[derive(Debug, Clone, ValueEnum)]
enum Task {
    FillMask,
    Reranker,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long, default_value = "bge-reranker-base")]
    model: Model,

    #[arg(long, default_value = "reranker")]
    task: Task,

    // Path to the tokenizer file.
    #[arg(long)]
    tokenizer_file: Option<String>,

    // Path to the weight files.
    #[arg(long)]
    weight_files: Option<String>,

    // Path to the config file.
    #[arg(long)]
    config_file: Option<String>,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,
}

fn main() -> Result<()> {
    let args = Args::parse();
    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => model_id.to_string(),
        None => match args.task {
            Task::FillMask => match args.model {
                Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
                Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
                _ => anyhow::bail!("BGE models are not supported for fill-mask task"),
            },
            Task::Reranker => match args.model {
                Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
                Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
                Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
                _ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
            },
        },
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));

    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };

    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
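
    // Prefer `model.safetensors`; fall back to `pytorch_model.bin` when it is not available.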
    let weights_filename = match args.weight_files {
        Some(files) => PathBuf::from(files),
        None => match repo.get("model.safetensors") {
            Ok(safetensors) => safetensors,
            Err(_) => match repo.get("pytorch_model.bin") {
                Ok(pytorch_model) => pytorch_model,
                Err(e) => {
                    return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
                }
            },
        },
    };

    let config = std::fs::read_to_string(config_filename)?;
    let config: Config = serde_json::from_str(&config)?;
    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let device = candle_examples::device(args.cpu)?;
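
    // Load the weights in f16; safetensors files are memory-mapped rather than read into memory.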
    let vb = if weights_filename.ends_with("model.safetensors") {
        unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
                .unwrap()
        }
    } else {
        println!("Loading weights from pytorch_model.bin");
        VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
    };
    tokenizer
        .with_padding(Some(PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            pad_id: config.pad_token_id,
            ..Default::default()
        }))
        .with_truncation(None)
        .map_err(E::msg)?;

    match args.task {
        Task::FillMask => {
            let prompt = vec![
                "Hello I'm a <mask> model.".to_string(),
                "I'm a <mask> boy.".to_string(),
                "I'm <mask> in berlin.".to_string(),
            ];
            let model = XLMRobertaForMaskedLM::new(&config, vb)?;

            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
            let attention_mask =
                get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;

            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;

            let output = model
                .forward(
                    &input_ids,
                    &attention_mask,
                    &token_type_ids,
                    None,
                    None,
                    None,
                )?
                .to_dtype(candle::DType::F32)?;
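
            // Greedy fill-in: argmax over the vocabulary dimension picks the most likely token at
            // every position, so decoding each sequence yields the sentence with <mask> filled in.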
            let max_outs = output.argmax(2)?;

            let max_out = max_outs.to_vec2::<u32>()?;
            let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
            let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
            for (i, sentence) in decoded.iter().enumerate() {
                println!("Sentence: {} : {}", i + 1, sentence);
            }
        }
        Task::Reranker => {
            let query = "what is panda?".to_string();

            let documents = ["South Korea is a country in East Asia.".to_string(),
                "There are forests in the mountains.".to_string(),
                "Pandas look like bears.".to_string(),
                "There are some animals with black and white fur.".to_string(),
                "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string()];

            // create pairs of query and documents
            let pairs = documents
                .iter()
                .map(|doc| (query.clone(), doc.clone()))
                .collect::<Vec<_>>();
            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
            let attention_mask =
                get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;

            let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;

            let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
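            // Each (query, document) pair produces a single logit; sigmoid maps it to a relevance
            // score in [0, 1] and the transpose gives shape (1, num_docs) for the ranking below.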
            let output = candle_nn::ops::sigmoid(&output)?.t().unwrap();
            let ranks = output
                .arg_sort_last_dim(false)?
                .to_vec2::<u32>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>();
            println!("\nRanking Results:");
            println!("{:-<80}", "");
            documents.iter().enumerate().for_each(|(idx, doc)| {
                let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
                let score = output
                    .get_on_dim(1, idx)
                    .unwrap()
                    .to_dtype(candle::DType::F32)
                    .unwrap()
                    .to_vec1::<f32>()
                    .unwrap();
                println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
            });
            println!("{:-<80}", "");
        }
    }
    Ok(())
}
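
// Input to the tokenization helpers below: either a batch of standalone sentences (fill-mask)
// or a batch of (query, document) pairs (reranker).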
#[derive(Debug)]
pub enum TokenizeInput<'a> {
    Single(&'a [String]),
    Pairs(&'a [(String, String)]),
}

pub fn tokenize_batch(
    tokenizer: &Tokenizer,
    input: TokenizeInput,
    device: &Device,
) -> anyhow::Result<Tensor> {
    let tokens = match input {
        TokenizeInput::Single(text_batch) => tokenizer
            .encode_batch(text_batch.to_vec(), true)
            .map_err(E::msg)?,
        TokenizeInput::Pairs(pairs) => tokenizer
            .encode_batch(pairs.to_vec(), true)
            .map_err(E::msg)?,
    };

    let token_ids = tokens
        .iter()
        .map(|tokens| {
            let tokens = tokens.get_ids().to_vec();
            Tensor::new(tokens.as_slice(), device)
        })
        .collect::<candle::Result<Vec<_>>>()?;

    Ok(Tensor::stack(&token_ids, 0)?)
}

pub fn get_attention_mask(
    tokenizer: &Tokenizer,
    input: TokenizeInput,
    device: &Device,
) -> anyhow::Result<Tensor> {
    let tokens = match input {
        TokenizeInput::Single(text_batch) => tokenizer
            .encode_batch(text_batch.to_vec(), true)
            .map_err(E::msg)?,
        TokenizeInput::Pairs(pairs) => tokenizer
            .encode_batch(pairs.to_vec(), true)
            .map_err(E::msg)?,
    };

    let attention_mask = tokens
        .iter()
        .map(|tokens| {
            let tokens = tokens.get_attention_mask().to_vec();
            Tensor::new(tokens.as_slice(), device)
        })
        .collect::<candle::Result<Vec<_>>>()?;
    Ok(Tensor::stack(&attention_mask, 0)?)
}

candle-transformers/src/models/mod.rs
@@ -109,4 +109,5 @@ pub mod vit;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
pub mod xlm_roberta;
pub mod yi;

candle-transformers/src/models/xlm_roberta.rs (new file, 545 lines)
@@ -0,0 +1,545 @@
use crate::models::with_tracing::{linear, Linear};
use candle::{DType, Module, Result, Tensor};
use candle_nn::{
    embedding, layer_norm, ops::softmax_last_dim, Activation, Embedding, LayerNorm, VarBuilder,
};

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub hidden_size: usize,
    pub layer_norm_eps: f64,
    pub attention_probs_dropout_prob: f32,
    pub hidden_dropout_prob: f32,
    pub num_attention_heads: usize,
    pub position_embedding_type: String,
    pub intermediate_size: usize,
    pub hidden_act: Activation,
    pub num_hidden_layers: usize,
    pub vocab_size: usize,
    pub max_position_embeddings: usize,
    pub type_vocab_size: usize,
    pub pad_token_id: u32,
}

struct XLMRobertaEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Option<Embedding>,
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    padding_idx: u32,
    span: tracing::Span,
}

impl XLMRobertaEmbeddings {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings = embedding(
            config.vocab_size,
            config.hidden_size,
            vb.pp("word_embeddings"),
        )?;
        let position_embeddings = embedding(
            config.max_position_embeddings,
            config.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let token_type_embeddings = embedding(
            config.type_vocab_size,
            config.hidden_size,
            vb.pp("token_type_embeddings"),
        )?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            word_embeddings,
            position_embeddings: Some(position_embeddings),
            token_type_embeddings,
            layer_norm,
            padding_idx: config.pad_token_id,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }

    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_bsize, _) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
        if let Some(position_embeddings) = &self.position_embeddings {
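            // RoBERTa-style position ids: count non-padding tokens cumulatively and offset by
            // `padding_idx` so that padding positions all map to `padding_idx` (this mirrors
            // `create_position_ids_from_input_ids` in the Hugging Face implementation).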
            let mask = input_ids
                .ne(self.padding_idx)?
                .to_dtype(input_embeddings.dtype())?;
            let cumsum = mask.cumsum(1)?;
            let position_ids = (cumsum * mask)?
                .broadcast_add(
                    &Tensor::try_from(self.padding_idx)?
                        .to_dtype(input_embeddings.dtype())?
                        .to_device(input_embeddings.device())?,
                )?
                .to_dtype(candle::DType::U32)?;
            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
        }
        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}

struct XLMRobertaSelfAttention {
    num_attention_heads: usize,
    attention_head_size: usize,
    all_head_size: usize,
    query: Linear,
    key: Linear,
    value: Linear,
}

impl XLMRobertaSelfAttention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
        let all_head_size = cfg.num_attention_heads * attention_head_size;
        Ok(Self {
            num_attention_heads: cfg.num_attention_heads,
            attention_head_size,
            all_head_size,
            query: linear(cfg.hidden_size, all_head_size, vb.pp("query"))?,
            key: linear(cfg.hidden_size, all_head_size, vb.pp("key"))?,
            value: linear(cfg.hidden_size, all_head_size, vb.pp("value"))?,
        })
    }

    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
        let mut new_x_shape = x.dims().to_vec();
        new_x_shape[2] = self.num_attention_heads;
        new_x_shape.push(self.attention_head_size);
        let x = x.reshape(new_x_shape)?;
        x.permute((0, 2, 1, 3))?.contiguous()
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: &Tensor,
        past_key_value: Option<(&Tensor, &Tensor)>,
        encoder_attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let mixed_query_layer = self.query.forward(hidden_states)?;
        let is_cross_attention = encoder_hidden_states.is_some();
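        // Four cases: cross-attention with cached encoder key/values, fresh cross-attention over
        // the encoder states, self-attention extended with a key/value cache, or plain self-attention.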
        let (key_layer, value_layer, attention_mask) = if is_cross_attention
            && past_key_value.is_some()
        {
            let key_layer = past_key_value.unwrap().0.clone();
            let value_layer = past_key_value.unwrap().1.clone();
            let attention_mask = encoder_attention_mask.unwrap().clone();
            (key_layer, value_layer, Some(attention_mask))
        } else if is_cross_attention {
            let key_layer =
                self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
            let value_layer =
                self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
            let attention_mask = encoder_attention_mask.unwrap();
            (key_layer, value_layer, Some(attention_mask.clone()))
        } else if past_key_value.is_some() {
            let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
            let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
            key_layer = Tensor::cat(
                &[
                    past_key_value.clone().as_ref().unwrap().0.clone(),
                    key_layer,
                ],
                2,
            )?;
            value_layer = Tensor::cat(
                &[past_key_value.as_ref().unwrap().1.clone(), value_layer],
                2,
            )?;
            (key_layer, value_layer, Some(attention_mask.clone()))
        } else {
            let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
            let value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
            (key_layer, value_layer, Some(attention_mask.clone()))
        };

        let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
        let mut attention_scores = query_layer.matmul(&key_layer.transpose(2, 3)?)?;
        let scale = 1f64 / f64::sqrt(self.attention_head_size as f64);

        attention_scores = (attention_scores * scale)?;
        attention_scores = match attention_mask {
            None => attention_scores,
            Some(mask) => {
                attention_scores.broadcast_add(&mask.to_dtype(attention_scores.dtype())?)?
            }
        };
        let attention_probs = softmax_last_dim(&attention_scores)?;

        let context_layer = attention_probs
            .matmul(&value_layer)?
            .permute((0, 2, 1, 3))?
            .contiguous()?;
        let mut new_context_layer_shape =
            context_layer.dims()[..context_layer.dims().len() - 2].to_vec();
        new_context_layer_shape.push(self.all_head_size);
        let context_layer = context_layer.reshape(new_context_layer_shape)?;

        Ok(context_layer)
    }
}

struct XLMRobertaSelfOutput {
    dense: Linear,
    layernorm: LayerNorm,
}

impl XLMRobertaSelfOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layernorm =
            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self { dense, layernorm })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
        Ok(hidden_states)
    }
}

struct XLMRobertaAttention {
    output: XLMRobertaSelfOutput,
    self_attention: XLMRobertaSelfAttention,
}

impl XLMRobertaAttention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let output = XLMRobertaSelfOutput::new(cfg, vb.pp("output"))?;
        let self_attention = XLMRobertaSelfAttention::new(cfg, vb.pp("self"))?;
        Ok(Self {
            output,
            self_attention,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
        past_key_value: Option<(&Tensor, &Tensor)>,
    ) -> Result<(Tensor, Tensor)> {
        let self_outputs = self.self_attention.forward(
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            past_key_value,
            encoder_attention_mask,
        )?;
        let attention_output = self.output.forward(&self_outputs, hidden_states)?;
        Ok((attention_output, self_outputs))
    }
}

struct XLMRobertaOutput {
    dense: Linear,
    layernorm: LayerNorm,
}

impl XLMRobertaOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
        let layernorm =
            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self { dense, layernorm })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
        Ok(hidden_states)
    }
}

struct XLMRobertaIntermediate {
    dense: Linear,
    intermediate_act_fn: Activation,
}

impl XLMRobertaIntermediate {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
        let intermediate_act_fn = cfg.hidden_act;
        Ok(Self {
            dense,
            intermediate_act_fn,
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.intermediate_act_fn.forward(&hidden_states)?;
        Ok(hidden_states)
    }
}

struct XLMRobertaLayer {
    attention: XLMRobertaAttention,
    intermediate: XLMRobertaIntermediate,
    output: XLMRobertaOutput,
}

impl XLMRobertaLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention = XLMRobertaAttention::new(cfg, vb.pp("attention"))?;
        let intermediate = XLMRobertaIntermediate::new(cfg, vb.pp("intermediate"))?;
        let output = XLMRobertaOutput::new(cfg, vb.pp("output"))?;
        Ok(Self {
            attention,
            intermediate,
            output,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
        past_key_value: Option<(&Tensor, &Tensor)>,
    ) -> Result<(Tensor, Tensor)> {
        let self_attention_outputs = self.attention.forward(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
        )?;
        let attention_output = self_attention_outputs.0;
        let outputs = self_attention_outputs.1;
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        let layer_output = self
            .output
            .forward(&intermediate_output, &attention_output)?;
        Ok((layer_output, outputs))
    }
}

struct XLMRobertaEncoder {
    layers: Vec<XLMRobertaLayer>,
}

impl XLMRobertaEncoder {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let layers = (0..cfg.num_hidden_layers)
            .map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i))))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
        past_key_value: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        let mut hidden_states = hidden_states.clone();
        for layer_module in self.layers.iter() {
            let layer_outputs = layer_module.forward(
                &hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
            )?;
            hidden_states = layer_outputs.0;
        }
        Ok(hidden_states)
    }
}

pub struct XLMRobertaModel {
    encoder: XLMRobertaEncoder,
    embeddings: XLMRobertaEmbeddings,
}

impl XLMRobertaModel {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let encoder = XLMRobertaEncoder::new(cfg, vb.pp("encoder"))?;
        let embeddings = XLMRobertaEmbeddings::load(vb.pp("embeddings"), cfg)?;
        Ok(Self {
            encoder,
            embeddings,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        attention_mask: &Tensor,
        token_type_ids: &Tensor,
        past_key_value: Option<(&Tensor, &Tensor)>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let hidden_states = self.embeddings.forward(input_ids, token_type_ids)?;
        let attention_mask = prepare_4d_attention_mask(attention_mask, DType::F32, None)?
            .to_device(hidden_states.device())?;
        let hidden_states = self.encoder.forward(
            &hidden_states,
            &attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
        )?;
        Ok(hidden_states)
    }
}

struct XLMRobertaLMHead {
    dense: Linear,
    layer_norm: LayerNorm,
}

impl XLMRobertaLMHead {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm =
            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layer_norm"))?;
        Ok(Self { dense, layer_norm })
    }

    fn forward(&self, hidden_states: &Tensor, shared_embeddings: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = candle_nn::Activation::Gelu.forward(&hidden_states)?;
        let hidden_states = self.layer_norm.forward(&hidden_states)?;
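        // Tied output projection: multiplying by the transposed word-embedding matrix (passed in
        // by `XLMRobertaForMaskedLM::forward`) turns the hidden states into vocabulary logits.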
        let hidden_states = hidden_states.broadcast_matmul(shared_embeddings)?;
        Ok(hidden_states)
    }
}

pub struct XLMRobertaForMaskedLM {
    roberta: XLMRobertaModel,
    lm_head: XLMRobertaLMHead,
}

impl XLMRobertaForMaskedLM {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
        let lm_head = XLMRobertaLMHead::new(cfg, vb.pp("lm_head"))?;
        Ok(Self { roberta, lm_head })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        attention_mask: &Tensor,
        token_type_ids: &Tensor,
        past_key_value: Option<(&Tensor, &Tensor)>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let hidden_states = self.roberta.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            past_key_value,
            encoder_hidden_states,
            encoder_attention_mask,
        )?;
        let lm_logits = self.lm_head.forward(
            &hidden_states,
            &self
                .roberta
                .embeddings
                .word_embeddings
                .embeddings()
                .t()?
                .unsqueeze(0)?,
        )?;
        Ok(lm_logits)
    }
}

struct XLMRobertaClassificationHead {
    dense: Linear,
    out_proj: Linear,
}

impl XLMRobertaClassificationHead {
    fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let out_proj = linear(cfg.hidden_size, num_labels, vb.pp("out_proj"))?;
        Ok(Self { dense, out_proj })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
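        // Classify from the hidden state of the first token (<s>, i.e. the BOS/CLS position).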
        let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
        let hidden_states = self.dense.forward(&cls_states)?;
        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
        let hidden_states = self.out_proj.forward(&hidden_states)?;
        Ok(hidden_states)
    }
}

pub struct XLMRobertaForSequenceClassification {
    roberta: XLMRobertaModel,
    classifier: XLMRobertaClassificationHead,
}

impl XLMRobertaForSequenceClassification {
    pub fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
        let classifier = XLMRobertaClassificationHead::new(num_labels, cfg, vb.pp("classifier"))?;
        Ok(Self {
            roberta,
            classifier,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        attention_mask: &Tensor,
        token_type_ids: &Tensor,
    ) -> Result<Tensor> {
        let hidden_states =
            self.roberta
                .forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
        self.classifier.forward(&hidden_states)
    }
}
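
/// Expands a `(batch, seq_len)` padding mask into the additive `(batch, 1, tgt_len, src_len)` form
/// expected by the attention layers: 0 where attention is allowed and the dtype's minimum value
/// where it is masked out.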
fn prepare_4d_attention_mask(
    mask: &Tensor,
    dtype: DType,
    tgt_len: Option<usize>,
) -> Result<Tensor> {
    let bsz = mask.dim(0)?;
    let src_len = mask.dim(1)?;
    let tgt_len = tgt_len.unwrap_or(src_len);

    let expanded_mask = mask
        .unsqueeze(1)?
        .unsqueeze(2)?
        .expand((bsz, 1, tgt_len, src_len))?
        .to_dtype(dtype)?;

    let inverted_mask = (1.0 - expanded_mask)?;

    (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
}

fn get_dtype_min_val(dtype: DType) -> f64 {
    match dtype {
        DType::F32 => f32::MIN as f64,
        DType::F64 => f64::MIN,
        _ => panic!("Unsupported data type"),
    }
}