diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 57f6bf5b..58f42f38 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -1,7 +1,8 @@ #![allow(dead_code)] -use anyhow::Result as R; +use anyhow::Error as E; use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor}; +use clap::Parser; use std::collections::HashMap; const DTYPE: DType = DType::F32; @@ -125,6 +126,29 @@ impl Default for Config { } } +impl Config { + fn all_mini_lm_l6_v2() -> Self { + // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json + Self { + vocab_size: 30522, + hidden_size: 384, + num_hidden_layers: 6, + num_attention_heads: 12, + intermediate_size: 1536, + hidden_act: HiddenAct::Gelu, + hidden_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + classifier_dropout: None, + } + } +} + struct Embedding { embeddings: Tensor, } @@ -155,8 +179,8 @@ impl Linear { } fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result { - let weight = vb.get((size1, size2), &format!("{p}.weight"))?; - let bias = vb.get(size1, &format!("{p}.bias"))?; + let weight = vb.get((size2, size1), &format!("{p}.weight"))?; + let bias = vb.get(size2, &format!("{p}.bias"))?; Ok(Self::new(weight, bias)) } @@ -364,8 +388,8 @@ struct BertAttention { impl BertAttention { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result { - let self_attention = BertSelfAttention::load(&format!("{p}.self_attention"), vb, config)?; - let self_output = BertSelfOutput::load(&format!("{p}.self_output"), vb, config)?; + let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?; + let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?; Ok(Self { self_attention, self_output, @@ -477,7 +501,7 @@ impl BertEncoder { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result { let layers = (0..config.num_hidden_layers) .map(|index| { - let p = format!("{p}.{index}"); + let p = format!("{p}.layer.{index}"); BertLayer::load(&p, vb, config) }) .collect::>>()?; @@ -517,10 +541,47 @@ impl BertModel { } } -fn main() -> R<()> { - let device = Device::Cpu; - let vb = VarBuilder::zeros(DTYPE, device); - let config = Config::default(); - let _model = BertModel::load(&vb, &config)?; +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + tokenizer_config: String, + + #[arg(long)] + weights: String, +} + +fn main() -> anyhow::Result<()> { + use tokenizers::Tokenizer; + + let args = Args::parse(); + let device = if args.cpu { + Device::Cpu + } else { + Device::new_cuda(0)? + }; + + let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; + + let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; + let weights = weights.deserialize()?; + let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone()); + let config = Config::all_mini_lm_l6_v2(); + let model = BertModel::load(&vb, &config)?; + + let tokens = tokenizer + .encode("This is an example sentence", true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let token_ids = Tensor::new(&tokens[..], &device)?; + let position_ids: Vec<_> = (0..tokens.len() as u32).collect(); + let position_ids = Tensor::new(&position_ids[..], &device)?.unsqueeze(0)?; + let ys = model.forward(&token_ids, &position_ids)?; + println!("{ys}"); Ok(()) }