mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Get the tensors to be loaded properly.
This commit is contained in:
@ -1,7 +1,8 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use anyhow::Result as R;
|
||||
use anyhow::Error as E;
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
|
||||
use clap::Parser;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Tensor element type handed to the `VarBuilder` constructors in `main`.
const DTYPE: DType = DType::F32;
|
||||
@ -125,6 +126,29 @@ impl Default for Config {
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Builds a `Config` with fixed, hard-coded values.
///
/// Per the link below, the values are meant to mirror the `config.json`
/// published for `sentence-transformers/all-MiniLM-L6-v2`.
/// NOTE(review): the values are copied by hand, not read from the file —
/// re-check them against the linked JSON if the upstream model changes.
fn all_mini_lm_l6_v2() -> Self {
|
||||
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
|
||||
Self {
|
||||
vocab_size: 30522,
|
||||
hidden_size: 384,
|
||||
num_hidden_layers: 6,
|
||||
num_attention_heads: 12,
|
||||
intermediate_size: 1536,
|
||||
hidden_act: HiddenAct::Gelu,
|
||||
hidden_dropout_prob: 0.1,
|
||||
max_position_embeddings: 512,
|
||||
type_vocab_size: 2,
|
||||
initializer_range: 0.02,
|
||||
layer_norm_eps: 1e-12,
|
||||
pad_token_id: 0,
|
||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||
use_cache: true,
|
||||
classifier_dropout: None,
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around a single tensor field, `embeddings`.
///
/// NOTE(review): presumably the lookup table indexed by id; the `impl` that
/// uses it is not visible in this hunk — confirm there.
struct Embedding {
|
||||
embeddings: Tensor,
|
||||
}
|
||||
@ -155,8 +179,8 @@ impl Linear {
|
||||
}
|
||||
|
||||
/// Fetches the tensors named `{p}.weight` and `{p}.bias` from `vb` and wraps
/// them with `Self::new`.
///
/// Any error returned by `vb.get` is propagated with `?`.
///
/// NOTE(review): this span is a rendered diff without +/- markers. The first
/// `weight`/`bias` pair (shape `(size1, size2)` / `size1`) looks like the
/// pre-change side and the second pair (shape `(size2, size1)` / `size2`) the
/// post-change side — confirm against the commit. Read literally as Rust, the
/// second pair shadows the first.
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
let weight = vb.get((size1, size2), &format!("{p}.weight"))?;
|
||||
let bias = vb.get(size1, &format!("{p}.bias"))?;
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
let bias = vb.get(size2, &format!("{p}.bias"))?;
|
||||
Ok(Self::new(weight, bias))
|
||||
}
|
||||
|
||||
@ -364,8 +388,8 @@ struct BertAttention {
|
||||
|
||||
impl BertAttention {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let self_attention = BertSelfAttention::load(&format!("{p}.self_attention"), vb, config)?;
|
||||
let self_output = BertSelfOutput::load(&format!("{p}.self_output"), vb, config)?;
|
||||
let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
|
||||
let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
|
||||
Ok(Self {
|
||||
self_attention,
|
||||
self_output,
|
||||
@ -477,7 +501,7 @@ impl BertEncoder {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let layers = (0..config.num_hidden_layers)
|
||||
.map(|index| {
|
||||
let p = format!("{p}.{index}");
|
||||
let p = format!("{p}.layer.{index}");
|
||||
BertLayer::load(&p, vb, config)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
@ -517,10 +541,47 @@ impl BertModel {
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> R<()> {
|
||||
let device = Device::Cpu;
|
||||
let vb = VarBuilder::zeros(DTYPE, device);
|
||||
let config = Config::default();
|
||||
let _model = BertModel::load(&vb, &config)?;
|
||||
// Command-line arguments, parsed with clap's derive API (`Args::parse()` in
// `main`). The notes added below are plain `//` comments on purpose: `///`
// doc comments on a clap-derived struct are picked up as help text.
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
// Path handed to `Tokenizer::from_file` in `main`.
#[arg(long)]
|
||||
tokenizer_config: String,
|
||||
|
||||
// Path of the weights file that `main` memory-maps and deserializes.
#[arg(long)]
|
||||
weights: String,
|
||||
}
|
||||
|
||||
/// Entry point: loads a tokenizer and model weights from the paths given on
/// the command line, runs one hard-coded sentence through the model and
/// prints the result.
///
/// # Errors
///
/// Returns the first error from CUDA device creation, tokenizer loading,
/// mapping or deserializing the weights file, model loading, encoding,
/// tensor construction, or the forward pass.
fn main() -> anyhow::Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
// `--cpu` selects the CPU; otherwise CUDA device 0 is requested.
let device = if args.cpu {
|
||||
Device::Cpu
|
||||
} else {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
|
||||
// The tokenizer error is converted into an `anyhow::Error` via `E::msg`.
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
||||
|
||||
// NOTE(review): `unsafe` because the weights file is memory-mapped;
// presumably the file must not be modified while the mapping is alive —
// confirm against the `MmapedFile::new` safety contract.
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
||||
let config = Config::all_mini_lm_l6_v2();
|
||||
let model = BertModel::load(&vb, &config)?;
|
||||
|
||||
// Encode a fixed example sentence and keep only the token ids.
// NOTE(review): the `true` flag is presumably `add_special_tokens` — confirm.
let tokens = tokenizer
|
||||
.encode("This is an example sentence", true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let token_ids = Tensor::new(&tokens[..], &device)?;
|
||||
// Position ids are simply 0..tokens.len().
// NOTE(review): `as u32` would truncate if the length exceeded u32::MAX.
let position_ids: Vec<_> = (0..tokens.len() as u32).collect();
|
||||
// NOTE(review): `position_ids` gets `unsqueeze(0)` while `token_ids` does
// not — confirm that `forward` expects these two different ranks.
let position_ids = Tensor::new(&position_ids[..], &device)?.unsqueeze(0)?;
|
||||
let ys = model.forward(&token_ids, &position_ids)?;
|
||||
println!("{ys}");
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user