Add the varbuilder + check shapes.

This commit is contained in:
laurent
2023-07-03 15:32:20 +01:00
parent 895805be92
commit ad52b0377c
2 changed files with 163 additions and 62 deletions

View File

@ -10,6 +10,13 @@ pub enum Error {
got: DType,
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedShape {
msg: String,
expected: Shape,
got: Shape,
},
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
DimOutOfRange {
shape: Shape,

View File

@ -1,10 +1,68 @@
#![allow(dead_code)]
use anyhow::Result as R;
use candle::{DType, Device, Result, Tensor};
use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
use std::collections::HashMap;
const DTYPE: DType = DType::F32;
struct VarBuilder<'a> {
safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
dtype: DType,
device: Device,
}
impl<'a> VarBuilder<'a> {
pub fn from_safetensors(
safetensors: Vec<SafeTensors<'a>>,
dtype: DType,
device: Device,
) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
for k in sf.names() {
routing.insert(k.to_string(), index);
}
}
Self {
safetensors: Some((routing, safetensors)),
device,
dtype,
}
}
pub fn zeros(dtype: DType, device: Device) -> Self {
Self {
safetensors: None,
device,
dtype,
}
}
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
let s: Shape = s.into();
match &self.safetensors {
None => Tensor::zeros(s, self.dtype, &self.device),
Some((routing, safetensors)) => {
// Unwrap or 0 just to let the proper error flow.
let index = routing.get(tensor_name).unwrap_or(&0);
let tensor = safetensors[*index]
.tensor(tensor_name, &self.device)?
.to_dtype(self.dtype)?;
if *tensor.shape() != s {
let msg = format!("shape mismatch for {tensor_name}");
Err(candle::Error::UnexpectedShape {
msg,
expected: s,
got: tensor.shape().clone(),
})?
}
Ok(tensor)
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HiddenAct {
Gelu,
@ -76,6 +134,11 @@ impl Embedding {
Self { embeddings }
}
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let embeddings = vb.get((size1, size2), &format!("{p}.weight"))?;
Ok(Self::new(embeddings))
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
Tensor::embedding(indexes, &self.embeddings)
}
@ -83,15 +146,23 @@ impl Embedding {
struct Linear {
weight: Tensor,
bias: Tensor,
}
impl Linear {
fn new(weight: Tensor) -> Self {
Self { weight }
fn new(weight: Tensor, bias: Tensor) -> Self {
Self { weight, bias }
}
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get((size1, size2), &format!("{p}.weight"))?;
let bias = vb.get(size1, &format!("{p}.bias"))?;
Ok(Self::new(weight, bias))
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.weight.t()?)?;
let x = x.broadcast_add(&self.bias)?;
Ok(x)
}
}
@ -112,12 +183,19 @@ impl Dropout {
}
struct LayerNorm {
scale: Tensor,
weight: Tensor,
bias: Tensor,
}
impl LayerNorm {
fn new(scale: Tensor) -> Self {
Self { scale }
fn new(weight: Tensor, bias: Tensor) -> Self {
Self { weight, bias }
}
fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get(size, &format!("{p}.weight"))?;
let bias = vb.get(size, &format!("{p}.bias"))?;
Ok(Self { weight, bias })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
@ -125,9 +203,9 @@ impl LayerNorm {
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().r1()?;
let scale = self.scale.broadcast_as((seq_len, size))?;
let x = (scale * x_normed)?;
let x = x_normed
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
}
@ -144,25 +222,34 @@ struct BertEmbeddings {
}
impl BertEmbeddings {
fn load(device: &Device, config: &Config) -> Result<Self> {
let word_embeddings =
Tensor::zeros((config.vocab_size, config.hidden_size), DTYPE, device)?;
let position_embeddings = Tensor::zeros(
(config.max_position_embeddings, config.hidden_size),
DTYPE,
device,
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings = Embedding::load(
config.vocab_size,
config.hidden_size,
&format!("{p}.word_embeddings"),
vb,
)?;
let token_type_embeddings =
Tensor::zeros((config.type_vocab_size, config.hidden_size), DTYPE, device)?;
let layer_norm = Tensor::zeros((), DTYPE, device)?;
let position_embeddings = Embedding::load(
config.max_position_embeddings,
config.hidden_size,
&format!("{p}.position_embeddings"),
vb,
)?;
let token_type_embeddings = Embedding::load(
config.type_vocab_size,
config.hidden_size,
&format!("{p}.token_type_embeddings"),
vb,
)?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
let position_ids = Tensor::new(&position_ids[..], device)?.unsqueeze(0)?;
let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
let token_type_ids = position_ids.zeros_like()?;
Ok(Self {
word_embeddings: Embedding::new(word_embeddings),
position_embeddings: Some(Embedding::new(position_embeddings)),
token_type_embeddings: Embedding::new(token_type_embeddings),
layer_norm: LayerNorm::new(layer_norm),
word_embeddings,
position_embeddings: Some(position_embeddings),
token_type_embeddings,
layer_norm,
dropout: Dropout::new(config.hidden_dropout_prob),
position_ids,
token_type_ids,
@ -192,16 +279,14 @@ struct BertSelfAttention {
}
impl BertSelfAttention {
fn load(device: &Device, config: &Config) -> Result<Self> {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
let dropout = Dropout::new(config.hidden_dropout_prob);
let query = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
let query = Linear::new(query);
let value = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
let value = Linear::new(value);
let key = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
let key = Linear::new(key);
let hidden_size = config.hidden_size;
let query = Linear::load(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
let value = Linear::load(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
let key = Linear::load(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
Ok(Self {
query,
key,
@ -248,11 +333,14 @@ struct BertSelfOutput {
}
impl BertSelfOutput {
fn load(device: &Device, config: &Config) -> Result<Self> {
let dense = Tensor::zeros((config.hidden_size, config.hidden_size), DTYPE, device)?;
let dense = Linear::new(dense);
let layer_norm = Tensor::zeros((), DTYPE, device)?;
let layer_norm = LayerNorm::new(layer_norm);
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Linear::load(
config.hidden_size,
config.hidden_size,
&format!("{p}.dense"),
vb,
)?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
@ -275,9 +363,9 @@ struct BertAttention {
}
impl BertAttention {
fn load(device: &Device, config: &Config) -> Result<Self> {
let self_attention = BertSelfAttention::load(device, config)?;
let self_output = BertSelfOutput::load(device, config)?;
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let self_attention = BertSelfAttention::load(&format!("{p}.self_attention"), vb, config)?;
let self_output = BertSelfOutput::load(&format!("{p}.self_output"), vb, config)?;
Ok(Self {
self_attention,
self_output,
@ -298,13 +386,13 @@ struct BertIntermediate {
}
impl BertIntermediate {
fn load(device: &Device, config: &Config) -> Result<Self> {
let dense = Tensor::zeros(
(config.hidden_size, config.intermediate_size),
DTYPE,
device,
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Linear::load(
config.hidden_size,
config.intermediate_size,
&format!("{p}.dense"),
vb,
)?;
let dense = Linear::new(dense);
Ok(Self {
dense,
intermediate_act: config.hidden_act,
@ -325,15 +413,14 @@ struct BertOutput {
}
impl BertOutput {
fn load(device: &Device, config: &Config) -> Result<Self> {
let dense = Tensor::zeros(
(config.intermediate_size, config.hidden_size),
DTYPE,
device,
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = Linear::load(
config.intermediate_size,
config.hidden_size,
&format!("{p}.dense"),
vb,
)?;
let dense = Linear::new(dense);
let layer_norm = Tensor::zeros((), DTYPE, device)?;
let layer_norm = LayerNorm::new(layer_norm);
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
@ -357,10 +444,10 @@ struct BertLayer {
}
impl BertLayer {
fn load(device: &Device, config: &Config) -> Result<Self> {
let attention = BertAttention::load(device, config)?;
let intermediate = BertIntermediate::load(device, config)?;
let output = BertOutput::load(device, config)?;
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
Ok(Self {
attention,
intermediate,
@ -387,9 +474,12 @@ struct BertEncoder {
}
impl BertEncoder {
fn load(device: &Device, config: &Config) -> Result<Self> {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|_index| BertLayer::load(device, config))
.map(|index| {
let p = format!("{p}.{index}");
BertLayer::load(&p, vb, config)
})
.collect::<Result<Vec<_>>>()?;
Ok(BertEncoder { layers })
}
@ -411,9 +501,9 @@ struct BertModel {
}
impl BertModel {
fn load(device: &Device, config: &Config) -> Result<Self> {
let embeddings = BertEmbeddings::load(device, config)?;
let encoder = BertEncoder::load(device, config)?;
fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
let embeddings = BertEmbeddings::load("embeddings", vb, config)?;
let encoder = BertEncoder::load("encoder", vb, config)?;
Ok(Self {
embeddings,
encoder,
@ -428,5 +518,9 @@ impl BertModel {
}
fn main() -> R<()> {
let device = Device::Cpu;
let vb = VarBuilder::zeros(DTYPE, device);
let config = Config::default();
let _model = BertModel::load(&vb, &config)?;
Ok(())
}