Upgrading bert example to work with bert-base-uncased.

- Always take weights from the hub
- Optional `model_id` + `revision` to use safetensors version
  potentially
- Optional loading for `bert-base-uncased` (`weight` vs `gamma`).
- Take the config from the hub.
This commit is contained in:
Nicolas Patry
2023-07-04 14:12:14 +00:00
parent a8b38ff821
commit 43a007cba4
2 changed files with 63 additions and 16 deletions

View File

@ -1,10 +1,9 @@
#![allow(dead_code)]
// The tokenizer.json and weights should be retrieved from:
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
use anyhow::{Error as E, Result};
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use candle_hub::{api::Api, Repo, RepoType};
use clap::Parser;
use serde::Deserialize;
use std::collections::HashMap;
const DTYPE: DType = DType::F32;
@ -66,7 +65,8 @@ impl<'a> VarBuilder<'a> {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
Gelu,
Relu,
@ -84,13 +84,14 @@ impl HiddenAct {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
Absolute,
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq, Deserialize)]
struct Config {
vocab_size: usize,
hidden_size: usize,
@ -235,8 +236,22 @@ impl LayerNorm {
}
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get(size, &format!("{p}.weight"))?;
let bias = vb.get(size, &format!("{p}.bias"))?;
let (weight, bias) = match (
vb.get(size, &format!("{p}.weight")),
vb.get(size, &format!("{p}.bias")),
) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (
vb.get(size, &format!("{p}.gamma")),
vb.get(size, &format!("{p}.beta")),
) {
(weight, bias)
} else {
return Err(err.into());
}
}
};
Ok(Self { weight, bias, eps })
}
@ -567,8 +582,21 @@ struct BertModel {
impl BertModel {
fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
let embeddings = BertEmbeddings::load("embeddings", vb, config)?;
let encoder = BertEncoder::load("encoder", vb, config)?;
let (embeddings, encoder) = match (
BertEmbeddings::load("embeddings", vb, config),
BertEncoder::load("encoder", vb, config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
match (
BertEmbeddings::load("bert.embeddings", vb, config),
BertEncoder::load("bert.encoder", vb, config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
_ => return Err(err),
}
}
};
Ok(Self {
embeddings,
encoder,
@ -590,13 +618,14 @@ struct Args {
cpu: bool,
#[arg(long)]
tokenizer_config: String,
model_id: Option<String>,
#[arg(long)]
weights: String,
revision: Option<String>,
}
fn main() -> Result<()> {
#[tokio::main]
async fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
@ -606,13 +635,29 @@ fn main() -> Result<()> {
Device::new_cuda(0)?
};
let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
let default_revision = "refs/pr/21".to_string();
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let api = Api::new()?;
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
println!("building the model");
let config_filename = api.get(&repo, "config.json").await?;
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
let weights = api.get(&repo, "model.safetensors").await?;
let weights = unsafe { candle::safetensors::MmapedFile::new(weights)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
let config = Config::all_mini_lm_l6_v2();
let model = BertModel::load(&vb, &config)?;
let tokens = tokenizer