mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Enabling roberta for the example (it's the same model as Bert, with just different naming).
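In Hugging Face checkpoints the two models store identical tensors under different name prefixes, e.g. `bert.embeddings.word_embeddings.weight` in `bert-base-uncased` versus `roberta.embeddings.word_embeddings.weight` in `roberta-base` (illustrative tensor names following the usual Hub convention). That naming difference is what the `model_type`-based loading below accounts for.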
@@ -9,7 +9,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
-use tokenizers::Tokenizer;
+use tokenizers::{PaddingParams, Tokenizer};

 const DTYPE: DType = DType::F32;

@@ -89,9 +89,10 @@ impl HiddenAct {
     }
 }

-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
+    #[default]
     Absolute,
 }

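`#[derive(Default)]` on an enum together with a `#[default]` variant marker (stable since Rust 1.62) is shorthand for the impl below; serde then uses it whenever a `#[serde(default)]` field of this type is missing from the JSON. A hand-written equivalent, for reference:

    // What the derive above expands to, conceptually:
    impl Default for PositionEmbeddingType {
        fn default() -> Self {
            Self::Absolute
        }
    }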
@@ -110,9 +111,12 @@ struct Config {
     initializer_range: f64,
     layer_norm_eps: f64,
     pad_token_id: usize,
+    #[serde(default)]
     position_embedding_type: PositionEmbeddingType,
+    #[serde(default)]
     use_cache: bool,
     classifier_dropout: Option<f64>,
+    model_type: Option<String>,
 }

 impl Default for Config {
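The `#[serde(default)]` attributes matter because RoBERTa config.json files on the Hub may omit these keys entirely. A minimal, self-contained sketch of the serde behaviour this relies on (the `MiniConfig` name and the JSON fragment are made up for illustration, not a real Hub config):

    use serde::Deserialize;

    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Deserialize)]
    #[serde(rename_all = "lowercase")]
    enum PositionEmbeddingType {
        #[default]
        Absolute,
    }

    #[derive(Debug, Deserialize)]
    struct MiniConfig {
        #[serde(default)]
        position_embedding_type: PositionEmbeddingType,
        #[serde(default)]
        use_cache: bool,
        model_type: Option<String>,
    }

    fn main() -> serde_json::Result<()> {
        // Both defaulted keys are absent; `Option` fields need no attribute,
        // a missing `model_type` simply becomes `None`.
        let cfg: MiniConfig = serde_json::from_str(r#"{ "model_type": "roberta" }"#)?;
        assert_eq!(cfg.position_embedding_type, PositionEmbeddingType::Absolute);
        assert!(!cfg.use_cache); // bool's serde default is false
        assert_eq!(cfg.model_type.as_deref(), Some("roberta"));
        Ok(())
    }

Note that `#[serde(default)]` on `use_cache` yields `false` when the key is absent from the JSON, independent of the `use_cache: true` in the hand-written constructors below.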
@@ -133,6 +137,7 @@ impl Default for Config {
             position_embedding_type: PositionEmbeddingType::Absolute,
             use_cache: true,
             classifier_dropout: None,
+            model_type: Some("bert".to_string()),
         }
     }
 }
@@ -156,6 +161,7 @@ impl Config {
             position_embedding_type: PositionEmbeddingType::Absolute,
             use_cache: true,
             classifier_dropout: None,
+            model_type: Some("bert".to_string()),
         }
     }
 }
@@ -594,12 +600,17 @@ impl BertModel {
         ) {
             (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
             (Err(err), _) | (_, Err(err)) => {
-                match (
-                    BertEmbeddings::load("bert.embeddings", vb, config),
-                    BertEncoder::load("bert.encoder", vb, config),
-                ) {
-                    (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
-                    _ => return Err(err),
-                }
+                if let Some(model_type) = &config.model_type {
+                    if let (Ok(embeddings), Ok(encoder)) = (
+                        BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
+                        BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
+                    ) {
+                        (embeddings, encoder)
+                    } else {
+                        return Err(err);
+                    }
+                } else {
+                    return Err(err);
+                }
             }
         };
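The new control flow tries the plain tensor names first and only then retries under the `model_type` prefix, preserving the original error. A condensed sketch of that chain (a hypothetical helper, not part of the diff; anyhow-style `Result` assumed, with `try_load` standing in for the paired BertEmbeddings/BertEncoder loads):

    use anyhow::Result;

    // Hypothetical helper mirroring the fallback above; `try_load` receives
    // the name prefix to probe ("" for the plain names, "roberta." etc.).
    fn load_with_fallback<T>(
        try_load: impl Fn(&str) -> Result<T>,
        model_type: Option<&str>,
    ) -> Result<T> {
        match try_load("") {
            Ok(t) => Ok(t), // plain "embeddings" / "encoder" names worked
            Err(err) => match model_type {
                // Retry under the config's model_type prefix, but report
                // the original error if that attempt fails as well.
                Some(mt) => try_load(&format!("{mt}.")).map_err(|_| err),
                // No model_type to fall back to: surface the first error.
                None => Err(err),
            },
        }
    }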
@@ -731,6 +742,12 @@ async fn main() -> Result<()> {
     let n_sentences = sentences.len();
+    if let Some(pp) = tokenizer.get_padding_mut() {
+        pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+    } else {
+        let pp = PaddingParams {
+            strategy: tokenizers::PaddingStrategy::BatchLongest,
+            ..Default::default()
+        };
+        tokenizer.with_padding(Some(pp));
+    }
     let tokens = tokenizer
         .encode_batch(sentences.to_vec(), true)
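`BatchLongest` pads every sequence in a batch to its longest member, which is what makes the later `Tensor::stack` possible: sequences of unequal length could not be stacked into one `(n_sentences, seq_len)` tensor. A quick check of that behaviour, assuming it runs inside `main` right after the padding setup above:

    // All encodings in a batch come back with the same (longest) length.
    let encodings = tokenizer
        .encode_batch(vec!["short", "a noticeably longer input sentence"], true)
        .map_err(anyhow::Error::msg)?;
    let len = encodings[0].get_ids().len();
    assert!(encodings.iter().all(|e| e.get_ids().len() == len));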
@@ -742,6 +759,7 @@ async fn main() -> Result<()> {
             Ok(Tensor::new(tokens.as_slice(), device)?)
         })
         .collect::<Result<Vec<_>>>()?;

     let token_ids = Tensor::stack(&token_ids, 0)?;
+    let token_type_ids = token_ids.zeros_like()?;
     println!("running inference on batch {:?}", token_ids.shape());
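`token_type_ids` carries BERT's segment ids (sentence A vs. sentence B); since each input here is a single sentence, `zeros_like` builds an all-zero tensor with the same shape and dtype as `token_ids`. The batch then feeds the encoder, with a call along the lines of this example's `BertModel`:

    // Single-segment inputs: every token belongs to segment 0.
    let embeddings = model.forward(&token_ids, &token_type_ids)?;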