From 3f291bdf9dbddd1cf00744d621aee02010c7d503 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 6 Jul 2023 13:25:21 +0200
Subject: [PATCH] Enabling `roberta` for the example (it's the same model as Bert, with just different naming.)

---
 candle-examples/examples/bert/main.rs | 34 ++++++++++++++++++++++++++--------
 1 file changed, 26 insertions(+), 8 deletions(-)

diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index bf99b1bf..f9f83c4a 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -9,7 +9,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
-use tokenizers::Tokenizer;
+use tokenizers::{PaddingParams, Tokenizer};
 
 const DTYPE: DType = DType::F32;
 
@@ -89,9 +89,10 @@ impl HiddenAct {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
+    #[default]
     Absolute,
 }
 
@@ -110,9 +111,12 @@ struct Config {
     initializer_range: f64,
     layer_norm_eps: f64,
     pad_token_id: usize,
+    #[serde(default)]
     position_embedding_type: PositionEmbeddingType,
+    #[serde(default)]
     use_cache: bool,
     classifier_dropout: Option<f64>,
+    model_type: Option<String>,
 }
 
 impl Default for Config {
@@ -133,6 +137,7 @@ impl Default for Config {
             position_embedding_type: PositionEmbeddingType::Absolute,
             use_cache: true,
             classifier_dropout: None,
+            model_type: Some("bert".to_string()),
         }
     }
 }
@@ -156,6 +161,7 @@ impl Config {
             position_embedding_type: PositionEmbeddingType::Absolute,
             use_cache: true,
             classifier_dropout: None,
+            model_type: Some("bert".to_string()),
         }
     }
 }
@@ -594,12 +600,17 @@ impl BertModel {
         ) {
             (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
             (Err(err), _) | (_, Err(err)) => {
-                match (
-                    BertEmbeddings::load("bert.embeddings", vb, config),
-                    BertEncoder::load("bert.encoder", vb, config),
-                ) {
-                    (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
-                    _ => return Err(err),
+                if let Some(model_type) = &config.model_type {
+                    if let (Ok(embeddings), Ok(encoder)) = (
+                        BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
+                        BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
+                    ) {
+                        (embeddings, encoder)
+                    } else {
+                        return Err(err);
+                    }
+                } else {
+                    return Err(err);
                 }
             }
         };
@@ -731,6 +742,12 @@ async fn main() -> Result<()> {
         let n_sentences = sentences.len();
         if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        } else {
+            let pp = PaddingParams {
+                strategy: tokenizers::PaddingStrategy::BatchLongest,
+                ..Default::default()
+            };
+            tokenizer.with_padding(Some(pp));
         }
         let tokens = tokenizer
             .encode_batch(sentences.to_vec(), true)
@@ -742,6 +759,7 @@ async fn main() -> Result<()> {
                 Ok(Tensor::new(tokens.as_slice(), device)?)
             })
             .collect::<Result<Vec<_>>>()?;
+        let token_ids = Tensor::stack(&token_ids, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         println!("running inference on batch {:?}", token_ids.shape());
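
The core of the change is the weight-name fallback in BertModel::load: roberta checkpoints store their tensors under "roberta.embeddings"/"roberta.encoder" rather than "bert.…", so when the unprefixed names fail the loader retries with a prefix derived from the config's model_type. A minimal, dependency-free sketch of that pattern (load_with_fallback and the String error type are illustrative stand-ins, not part of the patch):

    // Try the unprefixed tensor name first; on failure, retry with a
    // "<model_type>." prefix ("bert.embeddings", "roberta.embeddings", ...).
    // `load` stands in for BertEmbeddings::load / BertEncoder::load.
    fn load_with_fallback<T>(
        model_type: Option<&str>,
        load: impl Fn(&str) -> Result<T, String>,
    ) -> Result<T, String> {
        match load("embeddings") {
            Ok(t) => Ok(t),
            // If the prefixed retry also fails, surface the original error,
            // mirroring the `return Err(err)` in the patch.
            Err(err) => match model_type {
                Some(mt) => load(&format!("{mt}.embeddings")).map_err(|_| err),
                None => Err(err),
            },
        }
    }

Keeping the first error rather than the retry's is deliberate: the unprefixed names are the expected layout, so that error is the more useful diagnostic.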
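
The tokenizer change guards against checkpoints whose tokenizer ships without a padding configuration: when none is present, batch-longest padding is installed from defaults so encode_batch still produces equal-length sequences that can be stacked into one batch tensor. A standalone sketch using the tokenizers crate ("tokenizer.json" is a placeholder path; the assert merely illustrates the effect):

    use tokenizers::{PaddingParams, Tokenizer};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

        // Reuse the existing padding settings when present; otherwise
        // install batch-longest padding built from defaults, as the patch does.
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest;
        } else {
            tokenizer.with_padding(Some(PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            }));
        }

        let encodings =
            tokenizer.encode_batch(vec!["short", "a noticeably longer sentence"], true)?;
        // Every encoding in the batch is now padded to the longest one,
        // which is what allows the Tensor::stack call added at the end of the diff.
        assert_eq!(encodings[0].get_ids().len(), encodings[1].get_ids().len());
        Ok(())
    }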