mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Enabling roberta
for the example (it's the same model as Bert, with
just different naming.)
This commit is contained in:
@ -9,7 +9,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::{PaddingParams, Tokenizer};
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
@ -89,9 +89,10 @@ impl HiddenAct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
enum PositionEmbeddingType {
|
enum PositionEmbeddingType {
|
||||||
|
#[default]
|
||||||
Absolute,
|
Absolute,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,9 +111,12 @@ struct Config {
|
|||||||
initializer_range: f64,
|
initializer_range: f64,
|
||||||
layer_norm_eps: f64,
|
layer_norm_eps: f64,
|
||||||
pad_token_id: usize,
|
pad_token_id: usize,
|
||||||
|
#[serde(default)]
|
||||||
position_embedding_type: PositionEmbeddingType,
|
position_embedding_type: PositionEmbeddingType,
|
||||||
|
#[serde(default)]
|
||||||
use_cache: bool,
|
use_cache: bool,
|
||||||
classifier_dropout: Option<f64>,
|
classifier_dropout: Option<f64>,
|
||||||
|
model_type: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Config {
|
impl Default for Config {
|
||||||
@ -133,6 +137,7 @@ impl Default for Config {
|
|||||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||||
use_cache: true,
|
use_cache: true,
|
||||||
classifier_dropout: None,
|
classifier_dropout: None,
|
||||||
|
model_type: Some("bert".to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -156,6 +161,7 @@ impl Config {
|
|||||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||||
use_cache: true,
|
use_cache: true,
|
||||||
classifier_dropout: None,
|
classifier_dropout: None,
|
||||||
|
model_type: Some("bert".to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -594,12 +600,17 @@ impl BertModel {
|
|||||||
) {
|
) {
|
||||||
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
match (
|
if let Some(model_type) = &config.model_type {
|
||||||
BertEmbeddings::load("bert.embeddings", vb, config),
|
if let (Ok(embeddings), Ok(encoder)) = (
|
||||||
BertEncoder::load("bert.encoder", vb, config),
|
BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
|
||||||
) {
|
BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
|
||||||
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
) {
|
||||||
_ => return Err(err),
|
(embeddings, encoder)
|
||||||
|
} else {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -731,6 +742,12 @@ async fn main() -> Result<()> {
|
|||||||
let n_sentences = sentences.len();
|
let n_sentences = sentences.len();
|
||||||
if let Some(pp) = tokenizer.get_padding_mut() {
|
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||||
|
} else {
|
||||||
|
let pp = PaddingParams {
|
||||||
|
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
tokenizer.with_padding(Some(pp));
|
||||||
}
|
}
|
||||||
let tokens = tokenizer
|
let tokens = tokenizer
|
||||||
.encode_batch(sentences.to_vec(), true)
|
.encode_batch(sentences.to_vec(), true)
|
||||||
@ -742,6 +759,7 @@ async fn main() -> Result<()> {
|
|||||||
Ok(Tensor::new(tokens.as_slice(), device)?)
|
Ok(Tensor::new(tokens.as_slice(), device)?)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let token_ids = Tensor::stack(&token_ids, 0)?;
|
let token_ids = Tensor::stack(&token_ids, 0)?;
|
||||||
let token_type_ids = token_ids.zeros_like()?;
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
println!("running inference on batch {:?}", token_ids.shape());
|
println!("running inference on batch {:?}", token_ids.shape());
|
||||||
|
Reference in New Issue
Block a user