Add the varbuilder + check shapes.
@@ -10,6 +10,13 @@ pub enum Error {
         got: DType,
     },
 
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedShape {
+        msg: String,
+        expected: Shape,
+        got: Shape,
+    },
+
     #[error("{op}: dimension index {dim} out of range for {shape:?}")]
     DimOutOfRange {
         shape: Shape,
@@ -1,10 +1,68 @@
 #![allow(dead_code)]
 
 use anyhow::Result as R;
-use candle::{DType, Device, Result, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
+use std::collections::HashMap;
 
 const DTYPE: DType = DType::F32;
 
+struct VarBuilder<'a> {
+    safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
+    dtype: DType,
+    device: Device,
+}
+
+impl<'a> VarBuilder<'a> {
+    pub fn from_safetensors(
+        safetensors: Vec<SafeTensors<'a>>,
+        dtype: DType,
+        device: Device,
+    ) -> Self {
+        let mut routing = HashMap::new();
+        for (index, sf) in safetensors.iter().enumerate() {
+            for k in sf.names() {
+                routing.insert(k.to_string(), index);
+            }
+        }
+        Self {
+            safetensors: Some((routing, safetensors)),
+            device,
+            dtype,
+        }
+    }
+
+    pub fn zeros(dtype: DType, device: Device) -> Self {
+        Self {
+            safetensors: None,
+            device,
+            dtype,
+        }
+    }
+
+    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
+        let s: Shape = s.into();
+        match &self.safetensors {
+            None => Tensor::zeros(s, self.dtype, &self.device),
+            Some((routing, safetensors)) => {
+                // Unwrap or 0 just to let the proper error flow.
+                let index = routing.get(tensor_name).unwrap_or(&0);
+                let tensor = safetensors[*index]
+                    .tensor(tensor_name, &self.device)?
+                    .to_dtype(self.dtype)?;
+                if *tensor.shape() != s {
+                    let msg = format!("shape mismatch for {tensor_name}");
+                    Err(candle::Error::UnexpectedShape {
+                        msg,
+                        expected: s,
+                        got: tensor.shape().clone(),
+                    })?
+                }
+                Ok(tensor)
+            }
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum HiddenAct {
     Gelu,
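Note: a quick sketch of how the zeros-backed builder behaves, using only the API introduced above (the demo function and tensor name are illustrative). With no safetensors attached, `get` returns zeros of whatever shape is requested, so a model can be instantiated without weights on disk; the shape check only kicks in for real checkpoints.

fn zeros_builder_demo() -> Result<()> {
    // No routing table here: every lookup succeeds and returns zeros.
    let vb = VarBuilder::zeros(DType::F32, Device::Cpu);
    let w = vb.get((2, 3), "any.tensor.name")?;
    assert_eq!(*w.shape(), (2, 3).into());
    Ok(())
}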
@@ -76,6 +134,11 @@ impl Embedding {
         Self { embeddings }
     }
 
+    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let embeddings = vb.get((size1, size2), &format!("{p}.weight"))?;
+        Ok(Self::new(embeddings))
+    }
+
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
         Tensor::embedding(indexes, &self.embeddings)
     }
@@ -83,15 +146,23 @@ impl Embedding {
 
 struct Linear {
     weight: Tensor,
+    bias: Tensor,
 }
 
 impl Linear {
-    fn new(weight: Tensor) -> Self {
-        Self { weight }
+    fn new(weight: Tensor, bias: Tensor) -> Self {
+        Self { weight, bias }
     }
 
+    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let weight = vb.get((size1, size2), &format!("{p}.weight"))?;
+        let bias = vb.get(size1, &format!("{p}.bias"))?;
+        Ok(Self::new(weight, bias))
+    }
+
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = x.matmul(&self.weight.t()?)?;
+        let x = x.broadcast_add(&self.bias)?;
         Ok(x)
     }
 }
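Note: a minimal smoke test for the updated Linear, as a sketch; the (out, in) weight layout is implied by the `weight.t()` call in `forward`, and the test function name is illustrative.

fn linear_bias_smoke_test() -> Result<()> {
    let device = Device::Cpu;
    // Weight is stored as (out_features, in_features) and transposed in forward.
    let weight = Tensor::zeros((4, 3), DType::F32, &device)?;
    let bias = Tensor::zeros(4, DType::F32, &device)?;
    let linear = Linear::new(weight, bias);
    let x = Tensor::zeros((2, 3), DType::F32, &device)?; // (batch, in)
    let y = linear.forward(&x)?;
    assert_eq!(*y.shape(), (2, 4).into()); // bias broadcasts over the batch dim
    Ok(())
}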
@@ -112,12 +183,19 @@ impl Dropout {
 }
 
 struct LayerNorm {
-    scale: Tensor,
+    weight: Tensor,
+    bias: Tensor,
 }
 
 impl LayerNorm {
-    fn new(scale: Tensor) -> Self {
-        Self { scale }
+    fn new(weight: Tensor, bias: Tensor) -> Self {
+        Self { weight, bias }
     }
 
+    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let weight = vb.get(size, &format!("{p}.weight"))?;
+        let bias = vb.get(size, &format!("{p}.bias"))?;
+        Ok(Self { weight, bias })
+    }
+
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -125,9 +203,9 @@ impl LayerNorm {
         let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let size = self.scale.shape().r1()?;
-        let scale = self.scale.broadcast_as((seq_len, size))?;
-        let x = (scale * x_normed)?;
+        let x = x_normed
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
         Ok(x)
     }
 }
@@ -144,25 +222,34 @@ struct BertEmbeddings {
 }
 
 impl BertEmbeddings {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let word_embeddings =
-            Tensor::zeros((config.vocab_size, config.hidden_size), DTYPE, device)?;
-        let position_embeddings = Tensor::zeros(
-            (config.max_position_embeddings, config.hidden_size),
-            DTYPE,
-            device,
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let word_embeddings = Embedding::load(
+            config.vocab_size,
+            config.hidden_size,
+            &format!("{p}.word_embeddings"),
+            vb,
         )?;
-        let token_type_embeddings =
-            Tensor::zeros((config.type_vocab_size, config.hidden_size), DTYPE, device)?;
-        let layer_norm = Tensor::zeros((), DTYPE, device)?;
+        let position_embeddings = Embedding::load(
+            config.max_position_embeddings,
+            config.hidden_size,
+            &format!("{p}.position_embeddings"),
+            vb,
+        )?;
+        let token_type_embeddings = Embedding::load(
+            config.type_vocab_size,
+            config.hidden_size,
+            &format!("{p}.token_type_embeddings"),
+            vb,
+        )?;
+        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
         let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
-        let position_ids = Tensor::new(&position_ids[..], device)?.unsqueeze(0)?;
+        let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
         let token_type_ids = position_ids.zeros_like()?;
         Ok(Self {
-            word_embeddings: Embedding::new(word_embeddings),
-            position_embeddings: Some(Embedding::new(position_embeddings)),
-            token_type_embeddings: Embedding::new(token_type_embeddings),
-            layer_norm: LayerNorm::new(layer_norm),
+            word_embeddings,
+            position_embeddings: Some(position_embeddings),
+            token_type_embeddings,
+            layer_norm,
             dropout: Dropout::new(config.hidden_dropout_prob),
             position_ids,
             token_type_ids,
@@ -192,16 +279,14 @@ struct BertSelfAttention {
 }
 
 impl BertSelfAttention {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
         let attention_head_size = config.hidden_size / config.num_attention_heads;
         let all_head_size = config.num_attention_heads * attention_head_size;
         let dropout = Dropout::new(config.hidden_dropout_prob);
-        let query = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
-        let query = Linear::new(query);
-        let value = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
-        let value = Linear::new(value);
-        let key = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
-        let key = Linear::new(key);
+        let hidden_size = config.hidden_size;
+        let query = Linear::load(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
+        let value = Linear::load(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
+        let key = Linear::load(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
         Ok(Self {
             query,
             key,
@@ -248,11 +333,14 @@ struct BertSelfOutput {
 }
 
 impl BertSelfOutput {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let dense = Tensor::zeros((config.hidden_size, config.hidden_size), DTYPE, device)?;
-        let dense = Linear::new(dense);
-        let layer_norm = Tensor::zeros((), DTYPE, device)?;
-        let layer_norm = LayerNorm::new(layer_norm);
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let dense = Linear::load(
+            config.hidden_size,
+            config.hidden_size,
+            &format!("{p}.dense"),
+            vb,
+        )?;
+        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -275,9 +363,9 @@ struct BertAttention {
 }
 
 impl BertAttention {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let self_attention = BertSelfAttention::load(device, config)?;
-        let self_output = BertSelfOutput::load(device, config)?;
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let self_attention = BertSelfAttention::load(&format!("{p}.self_attention"), vb, config)?;
+        let self_output = BertSelfOutput::load(&format!("{p}.self_output"), vb, config)?;
         Ok(Self {
             self_attention,
             self_output,
@@ -298,13 +386,13 @@ struct BertIntermediate {
 }
 
 impl BertIntermediate {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let dense = Tensor::zeros(
-            (config.hidden_size, config.intermediate_size),
-            DTYPE,
-            device,
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let dense = Linear::load(
+            config.hidden_size,
+            config.intermediate_size,
+            &format!("{p}.dense"),
+            vb,
         )?;
-        let dense = Linear::new(dense);
         Ok(Self {
             dense,
             intermediate_act: config.hidden_act,
@@ -325,15 +413,14 @@ struct BertOutput {
 }
 
 impl BertOutput {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let dense = Tensor::zeros(
-            (config.intermediate_size, config.hidden_size),
-            DTYPE,
-            device,
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let dense = Linear::load(
+            config.intermediate_size,
+            config.hidden_size,
+            &format!("{p}.dense"),
+            vb,
         )?;
-        let dense = Linear::new(dense);
-        let layer_norm = Tensor::zeros((), DTYPE, device)?;
-        let layer_norm = LayerNorm::new(layer_norm);
+        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -357,10 +444,10 @@ struct BertLayer {
 }
 
 impl BertLayer {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let attention = BertAttention::load(device, config)?;
-        let intermediate = BertIntermediate::load(device, config)?;
-        let output = BertOutput::load(device, config)?;
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
+        let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
+        let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
        Ok(Self {
             attention,
             intermediate,
@@ -387,9 +474,12 @@ struct BertEncoder {
 }
 
 impl BertEncoder {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
-            .map(|_index| BertLayer::load(device, config))
+            .map(|index| {
+                let p = format!("{p}.{index}");
+                BertLayer::load(&p, vb, config)
+            })
             .collect::<Result<Vec<_>>>()?;
         Ok(BertEncoder { layers })
     }
@@ -411,9 +501,9 @@ struct BertModel {
 }
 
 impl BertModel {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let embeddings = BertEmbeddings::load(device, config)?;
-        let encoder = BertEncoder::load(device, config)?;
+    fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let embeddings = BertEmbeddings::load("embeddings", vb, config)?;
+        let encoder = BertEncoder::load("encoder", vb, config)?;
         Ok(Self {
             embeddings,
             encoder,
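Note: for reference, a worked example of the tensor names this prefix scheme produces, derived purely from the nested format! calls in this diff (the helper function is illustrative; the prefixes have to match the names in whatever checkpoint is loaded):

fn query_weight_name(layer: usize) -> String {
    // "encoder" (BertModel) + ".{layer}" (BertEncoder) + ".attention" (BertLayer)
    // + ".self_attention" (BertAttention) + ".query" + ".weight" (Linear::load)
    format!("encoder.{layer}.attention.self_attention.query.weight")
}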
@@ -428,5 +518,9 @@ impl BertModel {
 }
 
 fn main() -> R<()> {
     let device = Device::Cpu;
+    let vb = VarBuilder::zeros(DTYPE, device);
     let config = Config::default();
+    let _model = BertModel::load(&vb, &config)?;
     Ok(())
 }
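Note: a sketch of how real weights could replace the zeros builder in main, assuming the re-exported SafeTensors type exposes the safetensors crate's deserialize; the file path and function name are illustrative, everything else uses only calls from this diff.

fn main_with_weights() -> R<()> {
    let device = Device::Cpu;
    // "model.safetensors" is an illustrative path, not part of this commit.
    let weights = std::fs::read("model.safetensors")?;
    let st = SafeTensors::deserialize(&weights)?;
    let vb = VarBuilder::from_safetensors(vec![st], DTYPE, device);
    let config = Config::default();
    // With real weights attached, every vb.get now shape-checks against the file.
    let _model = BertModel::load(&vb, &config)?;
    Ok(())
}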