Get the tensors to be loaded properly.

laurent
2023-07-03 15:53:31 +01:00
parent ad52b0377c
commit 54850e7525


@@ -1,7 +1,8 @@
 #![allow(dead_code)]
-use anyhow::Result as R;
+use anyhow::Error as E;
 use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
+use clap::Parser;
 use std::collections::HashMap;
 const DTYPE: DType = DType::F32;
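
The alias switch from anyhow::Result as R to anyhow::Error as E is there so that errors from crates that do not return anyhow::Error (the tokenizers loader used further down) can be wrapped with .map_err(E::msg), and clap::Parser is pulled in for the new Args struct. A minimal sketch of the same error-wrapping pattern, separate from the diff and assuming only the anyhow crate:

use anyhow::Error as E;

// Hypothetical fallible call whose error type is a plain String: it implements
// Display but not std::error::Error, so `?` alone cannot convert it.
fn parse_number(s: &str) -> Result<u32, String> {
    s.parse::<u32>().map_err(|e| e.to_string())
}

fn main() -> anyhow::Result<()> {
    // anyhow::Error::msg wraps any Display value into an anyhow::Error,
    // the same conversion applied to Tokenizer::from_file below.
    let n = parse_number("42").map_err(E::msg)?;
    println!("parsed {n}");
    Ok(())
}
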
@@ -125,6 +126,29 @@ impl Default for Config {
     }
 }
+impl Config {
+    fn all_mini_lm_l6_v2() -> Self {
+        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
+        Self {
+            vocab_size: 30522,
+            hidden_size: 384,
+            num_hidden_layers: 6,
+            num_attention_heads: 12,
+            intermediate_size: 1536,
+            hidden_act: HiddenAct::Gelu,
+            hidden_dropout_prob: 0.1,
+            max_position_embeddings: 512,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            classifier_dropout: None,
+        }
+    }
+}
 struct Embedding {
     embeddings: Tensor,
 }
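
The new constructor hardcodes the hyperparameters from the linked config.json of sentence-transformers/all-MiniLM-L6-v2 instead of relying on Config::default(). A quick consistency check on those numbers (a standalone sketch, not part of the commit): the hidden size must split evenly across the attention heads, and the feed-forward width is the usual 4x expansion.

fn main() {
    let hidden_size = 384usize;
    let num_attention_heads = 12usize;
    let intermediate_size = 1536usize;

    // Each attention head works on hidden_size / num_attention_heads dimensions.
    assert_eq!(hidden_size % num_attention_heads, 0);
    assert_eq!(hidden_size / num_attention_heads, 32);
    // The intermediate (feed-forward) width is 4 times the hidden size.
    assert_eq!(intermediate_size, 4 * hidden_size);
    println!("all-MiniLM-L6-v2 dimensions are consistent");
}
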
@@ -155,8 +179,8 @@ impl Linear {
     }
     fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size1, size2), &format!("{p}.weight"))?;
-        let bias = vb.get(size1, &format!("{p}.bias"))?;
+        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+        let bias = vb.get(size2, &format!("{p}.bias"))?;
         Ok(Self::new(weight, bias))
     }
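
This hunk is the heart of the "loaded properly" fix. PyTorch's nn.Linear, and therefore the Hugging Face BERT checkpoints, store weight with shape [out_features, in_features] and bias with shape [out_features]; the old code asked the VarBuilder for the transposed weight shape and the wrong bias length, which does not match what the checkpoint contains. The new shapes suggest size1 is the layer's input width and size2 its output width. A small sketch of that layout rule, independent of candle's API:

// Sketch (not from the commit): given a linear layer's input and output widths,
// return the shapes a PyTorch/Hugging Face checkpoint actually stores:
// weight as [out_features, in_features], bias as [out_features].
fn stored_shapes(in_dim: usize, out_dim: usize) -> ((usize, usize), usize) {
    ((out_dim, in_dim), out_dim)
}

fn main() {
    // e.g. the MiniLM feed-forward up-projection, 384 -> 1536:
    let (weight_shape, bias_shape) = stored_shapes(384, 1536);
    assert_eq!(weight_shape, (1536, 384));
    assert_eq!(bias_shape, 1536);
}
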
@@ -364,8 +388,8 @@ struct BertAttention {
 impl BertAttention {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let self_attention = BertSelfAttention::load(&format!("{p}.self_attention"), vb, config)?;
-        let self_output = BertSelfOutput::load(&format!("{p}.self_output"), vb, config)?;
+        let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
+        let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
         Ok(Self {
             self_attention,
             self_output,
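
The prefix changes line up with the submodule names used in Hugging Face BERT checkpoints, where an attention block stores its tensors under attention.self (the query/key/value projections) and attention.output (the output projection plus LayerNorm), not under the self_attention/self_output names the Rust structs use internally. A sketch of the resulting prefixes, assuming the enclosing prefix resolves to something like encoder.layer.0.attention:

fn main() {
    // Hypothetical prefix handed to BertAttention::load for the first layer.
    let p = "encoder.layer.0.attention";
    let self_prefix = format!("{p}.self");     // query/key/value projections
    let output_prefix = format!("{p}.output"); // dense projection + LayerNorm
    println!("{self_prefix}.query.weight");    // encoder.layer.0.attention.self.query.weight
    println!("{output_prefix}.dense.weight");  // encoder.layer.0.attention.output.dense.weight
}
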
@@ -477,7 +501,7 @@ impl BertEncoder {
     fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
             .map(|index| {
-                let p = format!("{p}.{index}");
+                let p = format!("{p}.layer.{index}");
                 BertLayer::load(&p, vb, config)
             })
             .collect::<Result<Vec<_>>>()?;
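
Same idea for the encoder: the checkpoint keys nest each transformer block under encoder.layer.{index}, so the literal layer segment has to be part of the prefix. A sketch that lists the prefixes the loader will now request, assuming the encoder is loaded under the prefix "encoder" and the six layers of the MiniLM config above:

fn main() {
    let p = "encoder"; // hypothetical prefix passed to BertEncoder::load
    let num_hidden_layers = 6;
    let prefixes: Vec<String> = (0..num_hidden_layers)
        .map(|index| format!("{p}.layer.{index}"))
        .collect();
    // Prints encoder.layer.0 through encoder.layer.5.
    for prefix in &prefixes {
        println!("{prefix}");
    }
}
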
@@ -517,10 +541,47 @@ impl BertModel {
     }
 }
-fn main() -> R<()> {
-    let device = Device::Cpu;
-    let vb = VarBuilder::zeros(DTYPE, device);
-    let config = Config::default();
-    let _model = BertModel::load(&vb, &config)?;
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+    #[arg(long)]
+    tokenizer_config: String,
+    #[arg(long)]
+    weights: String,
+}
+fn main() -> anyhow::Result<()> {
+    use tokenizers::Tokenizer;
+    let args = Args::parse();
+    let device = if args.cpu {
+        Device::Cpu
+    } else {
+        Device::new_cuda(0)?
+    };
+    let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
+    let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
+    let weights = weights.deserialize()?;
+    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
+    let config = Config::all_mini_lm_l6_v2();
+    let model = BertModel::load(&vb, &config)?;
+    let tokens = tokenizer
+        .encode("This is an example sentence", true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let token_ids = Tensor::new(&tokens[..], &device)?;
+    let position_ids: Vec<_> = (0..tokens.len() as u32).collect();
+    let position_ids = Tensor::new(&position_ids[..], &device)?.unsqueeze(0)?;
+    let ys = model.forward(&token_ids, &position_ids)?;
+    println!("{ys}");
     Ok(())
 }
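
The new main wires everything together: parse the command-line flags, pick CPU or the first CUDA device, load the tokenizer file, memory-map the safetensors weights and build the VarBuilder from them (instead of the previous all-zeros VarBuilder), load the model with the MiniLM config, then tokenize one example sentence and run a forward pass. Assuming this file is built as a "bert" example binary (the actual target name is not shown in this view), it could be run with something like: cargo run --example bert --release -- --tokenizer-config tokenizer.json --weights model.safetensors, adding --cpu to skip CUDA. The tokenizer.json and model.safetensors names are placeholders for the files from the all-MiniLM-L6-v2 repository, and the flag names come straight from the Args struct above (clap turns the tokenizer_config field into --tokenizer-config).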