Enabling roberta for the example (it's the same model as Bert, with just different naming.)
Nicolas Patry
2023-07-06 13:25:21 +02:00
parent c297a50960
commit 3f291bdf9d
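
The commit message's claim that RoBERTa is "the same model as Bert, with just different naming" refers to the tensor names inside the checkpoint: a BERT checkpoint typically prefixes its weights with "bert." while a RoBERTa checkpoint uses "roberta.". A hypothetical illustration (not part of the diff) of deriving that prefix from the model_type field of config.json rather than hard-coding "bert":

    // Hypothetical helper: build a checkpoint tensor name from the model_type
    // prefix found in config.json plus the architecture-level suffix.
    fn weight_name(model_type: &str, suffix: &str) -> String {
        format!("{model_type}.{suffix}")
    }

    fn main() {
        // BERT and RoBERTa store the same tensors under different prefixes.
        assert_eq!(
            weight_name("bert", "embeddings.word_embeddings.weight"),
            "bert.embeddings.word_embeddings.weight"
        );
        assert_eq!(
            weight_name("roberta", "embeddings.word_embeddings.weight"),
            "roberta.embeddings.word_embeddings.weight"
        );
    }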

@@ -9,7 +9,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
-use tokenizers::Tokenizer;
+use tokenizers::{PaddingParams, Tokenizer};
 
 const DTYPE: DType = DType::F32;
 
@@ -89,9 +89,10 @@ impl HiddenAct {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
+    #[default]
     Absolute,
 }
 
@ -110,9 +111,12 @@ struct Config {
initializer_range: f64, initializer_range: f64,
layer_norm_eps: f64, layer_norm_eps: f64,
pad_token_id: usize, pad_token_id: usize,
#[serde(default)]
position_embedding_type: PositionEmbeddingType, position_embedding_type: PositionEmbeddingType,
#[serde(default)]
use_cache: bool, use_cache: bool,
classifier_dropout: Option<f64>, classifier_dropout: Option<f64>,
model_type: Option<String>,
} }
impl Default for Config { impl Default for Config {
@@ -133,6 +137,7 @@ impl Default for Config {
             position_embedding_type: PositionEmbeddingType::Absolute,
             use_cache: true,
             classifier_dropout: None,
+            model_type: Some("bert".to_string()),
         }
     }
 }
@@ -156,6 +161,7 @@ impl Config {
             position_embedding_type: PositionEmbeddingType::Absolute,
             use_cache: true,
             classifier_dropout: None,
+            model_type: Some("bert".to_string()),
         }
     }
 }
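
The Default and #[serde(default)] additions above exist so that config.json files which omit position_embedding_type or use_cache, as RoBERTa-style configs often do, still deserialize instead of failing on the missing keys. A minimal sketch of that behavior under the same assumptions, using a hypothetical MiniConfig struct and the serde/serde_json crates:

    use serde::Deserialize;

    // Mirrors the enum in the example: Absolute is the default variant.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
    #[serde(rename_all = "lowercase")]
    enum PositionEmbeddingType {
        #[default]
        Absolute,
    }

    // Hypothetical cut-down config; the #[serde(default)] fields may be absent
    // from the JSON without causing an error.
    #[derive(Debug, Deserialize)]
    struct MiniConfig {
        hidden_size: usize,
        #[serde(default)]
        position_embedding_type: PositionEmbeddingType,
        #[serde(default)]
        use_cache: bool,
        model_type: Option<String>,
    }

    fn main() -> Result<(), serde_json::Error> {
        // Neither optional field is present in the JSON, yet this parses.
        let cfg: MiniConfig =
            serde_json::from_str(r#"{ "hidden_size": 768, "model_type": "roberta" }"#)?;
        assert_eq!(cfg.hidden_size, 768);
        assert_eq!(cfg.position_embedding_type, PositionEmbeddingType::Absolute);
        assert!(!cfg.use_cache);
        assert_eq!(cfg.model_type.as_deref(), Some("roberta"));
        Ok(())
    }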
@ -594,12 +600,17 @@ impl BertModel {
) { ) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder), (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => { (Err(err), _) | (_, Err(err)) => {
match ( if let Some(model_type) = &config.model_type {
BertEmbeddings::load("bert.embeddings", vb, config), if let (Ok(embeddings), Ok(encoder)) = (
BertEncoder::load("bert.encoder", vb, config), BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
) { BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder), ) {
_ => return Err(err), (embeddings, encoder)
} else {
return Err(err);
}
} else {
return Err(err);
} }
} }
}; };
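
The rewritten fallback in BertModel::load first tries the un-prefixed tensor names, then retries with the "{model_type}." prefix taken from the config, and only then surfaces the original error. An illustrative stand-in for that control flow (plain Rust, not the candle VarBuilder API):

    // Pretend the checkpoint only stores "roberta."-prefixed tensors.
    fn lookup(name: &str) -> Result<String, String> {
        if name == "roberta.embeddings" {
            Ok(name.to_string())
        } else {
            Err(format!("cannot find tensor {name}"))
        }
    }

    // Same fallback order as the diff: bare name, then prefixed name, then the
    // original error.
    fn load_with_fallback(model_type: Option<&str>) -> Result<String, String> {
        match lookup("embeddings") {
            Ok(t) => Ok(t),
            Err(err) => match model_type {
                // Keep the original error if the prefixed lookup also fails.
                Some(mt) => lookup(&format!("{mt}.embeddings")).map_err(|_| err),
                None => Err(err),
            },
        }
    }

    fn main() {
        assert_eq!(
            load_with_fallback(Some("roberta")).as_deref(),
            Ok("roberta.embeddings")
        );
        assert!(load_with_fallback(None).is_err());
    }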
@@ -731,6 +742,12 @@ async fn main() -> Result<()> {
     let n_sentences = sentences.len();
     if let Some(pp) = tokenizer.get_padding_mut() {
         pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+    } else {
+        let pp = PaddingParams {
+            strategy: tokenizers::PaddingStrategy::BatchLongest,
+            ..Default::default()
+        };
+        tokenizer.with_padding(Some(pp));
     }
     let tokens = tokenizer
         .encode_batch(sentences.to_vec(), true)
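
The padding change makes sure the batch is padded to its longest sequence even when the tokenizer file ships no padding configuration of its own. A hedged sketch of the same setup, assuming the tokenizers crate and a local tokenizer.json (hypothetical path; for instance one downloaded from a sentence-embedding model repo):

    use tokenizers::{PaddingParams, Tokenizer};

    fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;
        // Reuse an existing padding config if one is present, otherwise install
        // a fresh one that pads every batch to its longest sequence.
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest;
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let encodings =
            tokenizer.encode_batch(vec!["first sentence", "a longer second sentence"], true)?;
        // Both encodings now have the same length thanks to batch-longest padding.
        assert_eq!(encodings[0].get_ids().len(), encodings[1].get_ids().len());
        Ok(())
    }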
@@ -742,6 +759,7 @@ async fn main() -> Result<()> {
             Ok(Tensor::new(tokens.as_slice(), device)?)
         })
         .collect::<Result<Vec<_>>>()?;
+
     let token_ids = Tensor::stack(&token_ids, 0)?;
     let token_type_ids = token_ids.zeros_like()?;
     println!("running inference on batch {:?}", token_ids.shape());
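
Batch-longest padding is what makes the subsequent Tensor::stack along dimension 0 well-formed: every padded row has the same length, so the batch forms a rectangular (n_sentences, seq_len) matrix. A plain-Rust stand-in for that invariant (not the candle API):

    // Returns the (n_sentences, seq_len) shape a stacked batch would have, or
    // an error if the rows were not padded to a common length.
    fn stack_rows(rows: &[Vec<u32>]) -> Result<(usize, usize), String> {
        let seq_len = rows.first().map(|r| r.len()).unwrap_or(0);
        if rows.iter().any(|r| r.len() != seq_len) {
            return Err("rows must be padded to the same length before stacking".into());
        }
        Ok((rows.len(), seq_len))
    }

    fn main() {
        // Two sentences padded to the batch-longest length of 4 (0 as pad id).
        let batch = vec![vec![101, 2054, 102, 0], vec![101, 2054, 2003, 102]];
        assert_eq!(stack_rows(&batch), Ok((2, 4)));
    }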