Get the tensors to be loaded properly.

laurent
2023-07-03 15:53:31 +01:00
parent ad52b0377c
commit 54850e7525


@@ -1,7 +1,8 @@
#![allow(dead_code)]
-use anyhow::Result as R;
+use anyhow::Error as E;
use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
+use clap::Parser;
use std::collections::HashMap;
const DTYPE: DType = DType::F32;
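The two import changes line up with the rewritten main below: anyhow::Error as E wraps the boxed errors returned by the tokenizers crate (via .map_err(E::msg), since ? cannot convert them into anyhow::Error directly), and clap::Parser drives the new command-line arguments.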
@@ -125,6 +126,29 @@ impl Default for Config {
    }
}
+impl Config {
+    fn all_mini_lm_l6_v2() -> Self {
+        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
+        Self {
+            vocab_size: 30522,
+            hidden_size: 384,
+            num_hidden_layers: 6,
+            num_attention_heads: 12,
+            intermediate_size: 1536,
+            hidden_act: HiddenAct::Gelu,
+            hidden_dropout_prob: 0.1,
+            max_position_embeddings: 512,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            classifier_dropout: None,
+        }
+    }
+}
struct Embedding {
    embeddings: Tensor,
}
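The new constructor mirrors the linked config.json for sentence-transformers/all-MiniLM-L6-v2: a 6-layer encoder with a 384-dimensional hidden state split across 12 attention heads (32 dimensions per head) and a 1536-dimensional feed-forward layer.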
@@ -155,8 +179,8 @@ impl Linear {
    }

    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size1, size2), &format!("{p}.weight"))?;
-        let bias = vb.get(size1, &format!("{p}.bias"))?;
+        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+        let bias = vb.get(size2, &format!("{p}.bias"))?;
        Ok(Self::new(weight, bias))
    }
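The swapped dimensions above are the heart of this commit: PyTorch-style checkpoints, including the safetensors exports on the Hugging Face Hub, store a linear layer's weight as (out_features, in_features), so a layer mapping size1 inputs to size2 outputs is stored with shape (size2, size1) and a bias of length size2. A minimal sketch of the convention, assuming a 2-D input and using only basic candle tensor ops (linear_forward is a hypothetical helper, not code from this example):

use candle::{Result, Tensor};

// Hypothetical helper: the checkpoint stores W with shape (out, in),
// so the forward pass computes y = x W^T + b.
fn linear_forward(x: &Tensor, w: &Tensor, b: &Tensor) -> Result<Tensor> {
    // x: (n, in), w: (out, in) as loaded from the safetensors file
    let y = x.matmul(&w.t()?)?; // (n, in) x (in, out) -> (n, out)
    y.broadcast_add(b) // broadcast the (out,)-shaped bias over the rows
}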
@@ -364,8 +388,8 @@ struct BertAttention {
impl BertAttention {
    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let self_attention = BertSelfAttention::load(&format!("{p}.self_attention"), vb, config)?;
-        let self_output = BertSelfOutput::load(&format!("{p}.self_output"), vb, config)?;
+        let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
+        let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
        Ok(Self {
            self_attention,
            self_output,
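As with the Linear fix, these renames track the layout of exported Hugging Face BERT checkpoints, where the attention sub-modules live under attention.self and attention.output rather than self_attention and self_output.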
@@ -477,7 +501,7 @@ impl BertEncoder {
    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
        let layers = (0..config.num_hidden_layers)
            .map(|index| {
-                let p = format!("{p}.{index}");
+                let p = format!("{p}.layer.{index}");
                BertLayer::load(&p, vb, config)
            })
            .collect::<Result<Vec<_>>>()?;
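Same story one level up: encoder layers sit under an extra layer segment, e.g. encoder.layer.0.attention.self.query.weight, so with the old prefixes vb.get looked up names that don't exist in the file. A quick way to debug such mismatches is to dump the names the checkpoint actually contains; a rough sketch, where print_tensor_names is a hypothetical helper built on the SafeTensors type already imported at the top of the file:

use candle::safetensors::SafeTensors;

// Hypothetical debugging helper: print every tensor name in a safetensors
// file so prefix mismatches like `self_attention` vs `self` stand out.
fn print_tensor_names(path: &str) -> anyhow::Result<()> {
    let data = std::fs::read(path)?;
    let st = SafeTensors::deserialize(&data)?;
    for name in st.names() {
        println!("{name}"); // e.g. encoder.layer.0.attention.self.query.weight
    }
    Ok(())
}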
@@ -517,10 +541,47 @@ impl BertModel {
    }
}

-fn main() -> R<()> {
-    let device = Device::Cpu;
-    let vb = VarBuilder::zeros(DTYPE, device);
-    let config = Config::default();
-    let _model = BertModel::load(&vb, &config)?;
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long)]
+    tokenizer_config: String,
+
+    #[arg(long)]
+    weights: String,
+}
+
+fn main() -> anyhow::Result<()> {
+    use tokenizers::Tokenizer;
+    let args = Args::parse();
+    let device = if args.cpu {
+        Device::Cpu
+    } else {
+        Device::new_cuda(0)?
+    };
+    let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
+    let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
+    let weights = weights.deserialize()?;
+    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
+    let config = Config::all_mini_lm_l6_v2();
+    let model = BertModel::load(&vb, &config)?;
+    let tokens = tokenizer
+        .encode("This is an example sentence", true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let token_ids = Tensor::new(&tokens[..], &device)?;
+    let position_ids: Vec<_> = (0..tokens.len() as u32).collect();
+    let position_ids = Tensor::new(&position_ids[..], &device)?.unsqueeze(0)?;
+    let ys = model.forward(&token_ids, &position_ids)?;
+    println!("{ys}");
    Ok(())
}
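With the clap arguments in place, the example runs against real files; the invocation is roughly (the example name and file names are assumptions, not from the commit):

cargo run --example bert -- --cpu --tokenizer-config tokenizer.json --weights model.safetensors

where tokenizer.json and model.safetensors come from the all-MiniLM-L6-v2 repository on the Hub. Omitting --cpu takes the Device::new_cuda(0) path instead.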