mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Merge pull request #69 from LaurentMazare/upgrade_bert
Upgrading bert example to work with `bert-base-uncased`.
This commit is contained in:
@ -12,6 +12,8 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", default-features=false }
|
candle = { path = "../candle-core", default-features=false }
|
||||||
|
serde = { version = "1.0.166", features = ["derive"] }
|
||||||
|
serde_json = "1.0.99"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
// The tokenizer.json and weights should be retrieved from:
|
use anyhow::{anyhow, Error as E, Result};
|
||||||
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||||
|
use candle_hub::{api::Api, Cache, Repo, RepoType};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use serde::Deserialize;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
@ -66,7 +65,8 @@ impl<'a> VarBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
enum HiddenAct {
|
enum HiddenAct {
|
||||||
Gelu,
|
Gelu,
|
||||||
Relu,
|
Relu,
|
||||||
@ -84,13 +84,14 @@ impl HiddenAct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
enum PositionEmbeddingType {
|
enum PositionEmbeddingType {
|
||||||
Absolute,
|
Absolute,
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
struct Config {
|
struct Config {
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
hidden_size: usize,
|
hidden_size: usize,
|
||||||
@ -235,8 +236,22 @@ impl LayerNorm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
|
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
let weight = vb.get(size, &format!("{p}.weight"))?;
|
let (weight, bias) = match (
|
||||||
let bias = vb.get(size, &format!("{p}.bias"))?;
|
vb.get(size, &format!("{p}.weight")),
|
||||||
|
vb.get(size, &format!("{p}.bias")),
|
||||||
|
) {
|
||||||
|
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||||
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
|
if let (Ok(weight), Ok(bias)) = (
|
||||||
|
vb.get(size, &format!("{p}.gamma")),
|
||||||
|
vb.get(size, &format!("{p}.beta")),
|
||||||
|
) {
|
||||||
|
(weight, bias)
|
||||||
|
} else {
|
||||||
|
return Err(err.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
Ok(Self { weight, bias, eps })
|
Ok(Self { weight, bias, eps })
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -567,8 +582,21 @@ struct BertModel {
|
|||||||
|
|
||||||
impl BertModel {
|
impl BertModel {
|
||||||
fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let embeddings = BertEmbeddings::load("embeddings", vb, config)?;
|
let (embeddings, encoder) = match (
|
||||||
let encoder = BertEncoder::load("encoder", vb, config)?;
|
BertEmbeddings::load("embeddings", vb, config),
|
||||||
|
BertEncoder::load("encoder", vb, config),
|
||||||
|
) {
|
||||||
|
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||||
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
|
match (
|
||||||
|
BertEmbeddings::load("bert.embeddings", vb, config),
|
||||||
|
BertEncoder::load("bert.encoder", vb, config),
|
||||||
|
) {
|
||||||
|
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||||
|
_ => return Err(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
embeddings,
|
embeddings,
|
||||||
encoder,
|
encoder,
|
||||||
@ -589,15 +617,30 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Run offline (you must have the files already cached)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer_config: String,
|
offline: bool,
|
||||||
|
|
||||||
|
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
weights: String,
|
revision: Option<String>,
|
||||||
|
|
||||||
|
/// The number of times to run the prompt.
|
||||||
|
#[arg(long, default_value = "This is an example sentence")]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The number of times to run the prompt.
|
||||||
|
#[arg(long, default_value = "1")]
|
||||||
|
n: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let device = if args.cpu {
|
let device = if args.cpu {
|
||||||
@ -606,24 +649,60 @@ fn main() -> Result<()> {
|
|||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
||||||
|
let default_revision = "refs/pr/21".to_string();
|
||||||
|
let (model_id, revision) = match (args.model_id, args.revision) {
|
||||||
|
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||||
|
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||||
|
(None, Some(revision)) => (default_model, revision),
|
||||||
|
(None, None) => (default_model, default_revision),
|
||||||
|
};
|
||||||
|
|
||||||
|
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||||
|
let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
|
||||||
|
let cache = Cache::default();
|
||||||
|
(
|
||||||
|
cache
|
||||||
|
.get(&repo, "config.json")
|
||||||
|
.ok_or(anyhow!("Missing config file in cache"))?,
|
||||||
|
cache
|
||||||
|
.get(&repo, "tokenizer.json")
|
||||||
|
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
||||||
|
cache
|
||||||
|
.get(&repo, "model.safetensors")
|
||||||
|
.ok_or(anyhow!("Missing weights file in cache"))?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
let api = Api::new()?;
|
||||||
|
(
|
||||||
|
api.get(&repo, "config.json").await?,
|
||||||
|
api.get(&repo, "tokenizer.json").await?,
|
||||||
|
api.get(&repo, "model.safetensors").await?,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
|
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
||||||
let config = Config::all_mini_lm_l6_v2();
|
|
||||||
let model = BertModel::load(&vb, &config)?;
|
let model = BertModel::load(&vb, &config)?;
|
||||||
|
|
||||||
let tokens = tokenizer
|
let tokens = tokenizer
|
||||||
.encode("This is an example sentence", true)
|
.encode(args.prompt, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
||||||
println!("{token_ids}");
|
|
||||||
let token_type_ids = token_ids.zeros_like()?;
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
let ys = model.forward(&token_ids, &token_type_ids)?;
|
println!("Loaded and encoded {:?}", start.elapsed());
|
||||||
println!("{ys}");
|
for _ in 0..args.n {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let _ys = model.forward(&token_ids, &token_type_ids)?;
|
||||||
|
println!("Took {:?}", start.elapsed());
|
||||||
|
// println!("Ys {:?}", ys.shape());
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user