Merge pull request #89 from LaurentMazare/extending_bert

Enable `roberta` in the example (it is the same model as BERT, just with different naming).
This commit is contained in:
Nicolas Patry
2023-07-06 16:29:26 +02:00
committed by GitHub

View File

@ -9,7 +9,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
use clap::Parser;
use serde::Deserialize;
use std::collections::HashMap;
use tokenizers::Tokenizer;
use tokenizers::{PaddingParams, Tokenizer};
const DTYPE: DType = DType::F32;
@ -89,9 +89,10 @@ impl HiddenAct {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
#[default]
Absolute,
}
@ -110,9 +111,12 @@ struct Config {
initializer_range: f64,
layer_norm_eps: f64,
pad_token_id: usize,
#[serde(default)]
position_embedding_type: PositionEmbeddingType,
#[serde(default)]
use_cache: bool,
classifier_dropout: Option<f64>,
model_type: Option<String>,
}
impl Default for Config {
@ -133,6 +137,7 @@ impl Default for Config {
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
classifier_dropout: None,
model_type: Some("bert".to_string()),
}
}
}
@ -156,6 +161,7 @@ impl Config {
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
classifier_dropout: None,
model_type: Some("bert".to_string()),
}
}
}
@ -594,12 +600,17 @@ impl BertModel {
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
match (
BertEmbeddings::load("bert.embeddings", vb, config),
BertEncoder::load("bert.encoder", vb, config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
_ => return Err(err),
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
) {
(embeddings, encoder)
} else {
return Err(err);
}
} else {
return Err(err);
}
}
};
@ -731,6 +742,12 @@ async fn main() -> Result<()> {
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
@ -742,6 +759,7 @@ async fn main() -> Result<()> {
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let token_type_ids = token_ids.zeros_like()?;
println!("running inference on batch {:?}", token_ids.shape());