Add the varbuilder + check shapes.

This commit is contained in:
laurent
2023-07-03 15:32:20 +01:00
parent 895805be92
commit ad52b0377c
2 changed files with 163 additions and 62 deletions

View File

@ -10,6 +10,13 @@ pub enum Error {
got: DType, got: DType,
}, },
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedShape {
msg: String,
expected: Shape,
got: Shape,
},
#[error("{op}: dimension index {dim} out of range for {shape:?}")] #[error("{op}: dimension index {dim} out of range for {shape:?}")]
DimOutOfRange { DimOutOfRange {
shape: Shape, shape: Shape,

View File

@ -1,10 +1,68 @@
#![allow(dead_code)] #![allow(dead_code)]
use anyhow::Result as R; use anyhow::Result as R;
use candle::{DType, Device, Result, Tensor}; use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
use std::collections::HashMap;
const DTYPE: DType = DType::F32; const DTYPE: DType = DType::F32;
/// Source of model weights: tensors are either read from safetensors files or
/// created as zeros, depending on how the builder was constructed.
struct VarBuilder<'a> {
    /// `None` makes `get` hand out zero tensors. Otherwise holds a map from
    /// tensor name to an index into the accompanying list of safetensors files.
    safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
    /// Dtype of the tensors returned by `get`.
    dtype: DType,
    /// Device on which tensors are created or onto which they are loaded.
    device: Device,
}
impl<'a> VarBuilder<'a> {
    /// Creates a builder that reads its tensors from the given safetensors files.
    ///
    /// A name -> file-index routing table is computed once here. When a tensor
    /// name occurs in several files, the file with the highest index wins
    /// because later entries overwrite earlier ones.
    pub fn from_safetensors(
        safetensors: Vec<SafeTensors<'a>>,
        dtype: DType,
        device: Device,
    ) -> Self {
        let routing: HashMap<String, usize> = safetensors
            .iter()
            .enumerate()
            .flat_map(|(file_index, file)| {
                file.names()
                    .into_iter()
                    .map(move |name| (name.to_string(), file_index))
            })
            .collect();
        Self {
            safetensors: Some((routing, safetensors)),
            dtype,
            device,
        }
    }

    /// Creates a builder with no backing files; `get` then returns zero tensors.
    pub fn zeros(dtype: DType, device: Device) -> Self {
        Self {
            safetensors: None,
            dtype,
            device,
        }
    }

    /// Returns the tensor called `tensor_name`, expected to have shape `s`.
    ///
    /// Without backing files a zero tensor of shape `s` is created. Otherwise
    /// the tensor is loaded onto the builder's device, converted to the
    /// builder's dtype, and its shape is compared against `s`.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading or dtype conversion, and returns
    /// `candle::Error::UnexpectedShape` when the stored shape differs from `s`.
    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
        let expected: Shape = s.into();
        let (routing, files) = match self.safetensors.as_ref() {
            Some(backing) => backing,
            None => return Tensor::zeros(expected, self.dtype, &self.device),
        };
        // An unknown name is routed to file 0 on purpose, so that the loader
        // itself reports the missing tensor.
        // NOTE(review): this indexing panics if the file list is empty; confirm
        // callers never pass an empty Vec to `from_safetensors`.
        let file_index = routing.get(tensor_name).copied().unwrap_or(0);
        let loaded = files[file_index].tensor(tensor_name, &self.device)?;
        let tensor = loaded.to_dtype(self.dtype)?;
        if *tensor.shape() == expected {
            return Ok(tensor);
        }
        Err(candle::Error::UnexpectedShape {
            msg: format!("shape mismatch for {tensor_name}"),
            expected,
            got: tensor.shape().clone(),
        })
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HiddenAct { enum HiddenAct {
Gelu, Gelu,
@ -76,6 +134,11 @@ impl Embedding {
Self { embeddings } Self { embeddings }
} }
/// Builds an `Embedding` from the tensor named `{p}.weight`, requested from
/// `vb` with shape `(size1, size2)`.
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
    let weight_name = format!("{p}.weight");
    vb.get((size1, size2), &weight_name).map(Self::new)
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> { fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
Tensor::embedding(indexes, &self.embeddings) Tensor::embedding(indexes, &self.embeddings)
} }
@ -83,15 +146,23 @@ impl Embedding {
struct Linear { struct Linear {
weight: Tensor, weight: Tensor,
bias: Tensor,
} }
impl Linear { impl Linear {
fn new(weight: Tensor) -> Self { fn new(weight: Tensor, bias: Tensor) -> Self {
Self { weight } Self { weight, bias }
}
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get((size1, size2), &format!("{p}.weight"))?;
let bias = vb.get(size1, &format!("{p}.bias"))?;
Ok(Self::new(weight, bias))
} }
fn forward(&self, x: &Tensor) -> Result<Tensor> { fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.weight.t()?)?; let x = x.matmul(&self.weight.t()?)?;
let x = x.broadcast_add(&self.bias)?;
Ok(x) Ok(x)
} }
} }
@ -112,12 +183,19 @@ impl Dropout {
} }
struct LayerNorm { struct LayerNorm {
scale: Tensor, weight: Tensor,
bias: Tensor,
} }
impl LayerNorm { impl LayerNorm {
fn new(scale: Tensor) -> Self { fn new(weight: Tensor, bias: Tensor) -> Self {
Self { scale } Self { weight, bias }
}
fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get(size, &format!("{p}.weight"))?;
let bias = vb.get(size, &format!("{p}.bias"))?;
Ok(Self { weight, bias })
} }
fn forward(&self, x: &Tensor) -> Result<Tensor> { fn forward(&self, x: &Tensor) -> Result<Tensor> {
@ -125,9 +203,9 @@ impl LayerNorm {
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?; let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?; let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?; let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().r1()?; let x = x_normed
let scale = self.scale.broadcast_as((seq_len, size))?; .broadcast_mul(&self.weight)?
let x = (scale * x_normed)?; .broadcast_add(&self.bias)?;
Ok(x) Ok(x)
} }
} }
@ -144,25 +222,34 @@ struct BertEmbeddings {
} }
impl BertEmbeddings { impl BertEmbeddings {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings = let word_embeddings = Embedding::load(
Tensor::zeros((config.vocab_size, config.hidden_size), DTYPE, device)?; config.vocab_size,
let position_embeddings = Tensor::zeros( config.hidden_size,
(config.max_position_embeddings, config.hidden_size), &format!("{p}.word_embeddings"),
DTYPE, vb,
device,
)?; )?;
let token_type_embeddings = let position_embeddings = Embedding::load(
Tensor::zeros((config.type_vocab_size, config.hidden_size), DTYPE, device)?; config.max_position_embeddings,
let layer_norm = Tensor::zeros((), DTYPE, device)?; config.hidden_size,
&format!("{p}.position_embeddings"),
vb,
)?;
let token_type_embeddings = Embedding::load(
config.type_vocab_size,
config.hidden_size,
&format!("{p}.token_type_embeddings"),
vb,
)?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect(); let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
let position_ids = Tensor::new(&position_ids[..], device)?.unsqueeze(0)?; let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
let token_type_ids = position_ids.zeros_like()?; let token_type_ids = position_ids.zeros_like()?;
Ok(Self { Ok(Self {
word_embeddings: Embedding::new(word_embeddings), word_embeddings,
position_embeddings: Some(Embedding::new(position_embeddings)), position_embeddings: Some(position_embeddings),
token_type_embeddings: Embedding::new(token_type_embeddings), token_type_embeddings,
layer_norm: LayerNorm::new(layer_norm), layer_norm,
dropout: Dropout::new(config.hidden_dropout_prob), dropout: Dropout::new(config.hidden_dropout_prob),
position_ids, position_ids,
token_type_ids, token_type_ids,
@ -192,16 +279,14 @@ struct BertSelfAttention {
} }
impl BertSelfAttention { impl BertSelfAttention {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads; let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size; let all_head_size = config.num_attention_heads * attention_head_size;
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
let query = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?; let hidden_size = config.hidden_size;
let query = Linear::new(query); let query = Linear::load(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
let value = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?; let value = Linear::load(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
let value = Linear::new(value); let key = Linear::load(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
let key = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
let key = Linear::new(key);
Ok(Self { Ok(Self {
query, query,
key, key,
@ -248,11 +333,14 @@ struct BertSelfOutput {
} }
impl BertSelfOutput { impl BertSelfOutput {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Tensor::zeros((config.hidden_size, config.hidden_size), DTYPE, device)?; let dense = Linear::load(
let dense = Linear::new(dense); config.hidden_size,
let layer_norm = Tensor::zeros((), DTYPE, device)?; config.hidden_size,
let layer_norm = LayerNorm::new(layer_norm); &format!("{p}.dense"),
vb,
)?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self { Ok(Self {
dense, dense,
@ -275,9 +363,9 @@ struct BertAttention {
} }
impl BertAttention { impl BertAttention {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let self_attention = BertSelfAttention::load(device, config)?; let self_attention = BertSelfAttention::load(&format!("{p}.self_attention"), vb, config)?;
let self_output = BertSelfOutput::load(device, config)?; let self_output = BertSelfOutput::load(&format!("{p}.self_output"), vb, config)?;
Ok(Self { Ok(Self {
self_attention, self_attention,
self_output, self_output,
@ -298,13 +386,13 @@ struct BertIntermediate {
} }
impl BertIntermediate { impl BertIntermediate {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Tensor::zeros( let dense = Linear::load(
(config.hidden_size, config.intermediate_size), config.hidden_size,
DTYPE, config.intermediate_size,
device, &format!("{p}.dense"),
vb,
)?; )?;
let dense = Linear::new(dense);
Ok(Self { Ok(Self {
dense, dense,
intermediate_act: config.hidden_act, intermediate_act: config.hidden_act,
@ -325,15 +413,14 @@ struct BertOutput {
} }
impl BertOutput { impl BertOutput {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Tensor::zeros( let dense = Linear::load(
(config.intermediate_size, config.hidden_size), config.intermediate_size,
DTYPE, config.hidden_size,
device, &format!("{p}.dense"),
vb,
)?; )?;
let dense = Linear::new(dense); let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let layer_norm = Tensor::zeros((), DTYPE, device)?;
let layer_norm = LayerNorm::new(layer_norm);
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self { Ok(Self {
dense, dense,
@ -357,10 +444,10 @@ struct BertLayer {
} }
impl BertLayer { impl BertLayer {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let attention = BertAttention::load(device, config)?; let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
let intermediate = BertIntermediate::load(device, config)?; let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
let output = BertOutput::load(device, config)?; let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
Ok(Self { Ok(Self {
attention, attention,
intermediate, intermediate,
@ -387,9 +474,12 @@ struct BertEncoder {
} }
impl BertEncoder { impl BertEncoder {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers) let layers = (0..config.num_hidden_layers)
.map(|_index| BertLayer::load(device, config)) .map(|index| {
let p = format!("{p}.{index}");
BertLayer::load(&p, vb, config)
})
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(BertEncoder { layers }) Ok(BertEncoder { layers })
} }
@ -411,9 +501,9 @@ struct BertModel {
} }
impl BertModel { impl BertModel {
fn load(device: &Device, config: &Config) -> Result<Self> { fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
let embeddings = BertEmbeddings::load(device, config)?; let embeddings = BertEmbeddings::load("embeddings", vb, config)?;
let encoder = BertEncoder::load(device, config)?; let encoder = BertEncoder::load("encoder", vb, config)?;
Ok(Self { Ok(Self {
embeddings, embeddings,
encoder, encoder,
@ -428,5 +518,9 @@ impl BertModel {
} }
fn main() -> R<()> { fn main() -> R<()> {
let device = Device::Cpu;
let vb = VarBuilder::zeros(DTYPE, device);
let config = Config::default();
let _model = BertModel::load(&vb, &config)?;
Ok(()) Ok(())
} }