diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index d5de4296..2e82ab38 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -10,6 +10,13 @@ pub enum Error {
         got: DType,
     },
 
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedShape {
+        msg: String,
+        expected: Shape,
+        got: Shape,
+    },
+
     #[error("{op}: dimension index {dim} out of range for {shape:?}")]
     DimOutOfRange {
         shape: Shape,
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 52685c91..57f6bf5b 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -1,10 +1,68 @@
 #![allow(dead_code)]
 use anyhow::Result as R;
-use candle::{DType, Device, Result, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
+use std::collections::HashMap;
 
 const DTYPE: DType = DType::F32;
 
+struct VarBuilder<'a> {
+    safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
+    dtype: DType,
+    device: Device,
+}
+
+impl<'a> VarBuilder<'a> {
+    pub fn from_safetensors(
+        safetensors: Vec<SafeTensors<'a>>,
+        dtype: DType,
+        device: Device,
+    ) -> Self {
+        let mut routing = HashMap::new();
+        for (index, sf) in safetensors.iter().enumerate() {
+            for k in sf.names() {
+                routing.insert(k.to_string(), index);
+            }
+        }
+        Self {
+            safetensors: Some((routing, safetensors)),
+            device,
+            dtype,
+        }
+    }
+
+    pub fn zeros(dtype: DType, device: Device) -> Self {
+        Self {
+            safetensors: None,
+            device,
+            dtype,
+        }
+    }
+
+    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
+        let s: Shape = s.into();
+        match &self.safetensors {
+            None => Tensor::zeros(s, self.dtype, &self.device),
+            Some((routing, safetensors)) => {
+                // Fall back to index 0 on an unknown name so the lookup below
+                // reports the proper tensor-not-found error.
+                let index = routing.get(tensor_name).unwrap_or(&0);
+                let tensor = safetensors[*index]
+                    .tensor(tensor_name, &self.device)?
+                    .to_dtype(self.dtype)?;
+                if *tensor.shape() != s {
+                    let msg = format!("shape mismatch for {tensor_name}");
+                    Err(candle::Error::UnexpectedShape {
+                        msg,
+                        expected: s,
+                        got: tensor.shape().clone(),
+                    })?
+                }
+                Ok(tensor)
+            }
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum HiddenAct {
     Gelu,
@@ -76,6 +134,11 @@ impl Embedding {
         Self { embeddings }
     }
 
+    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let embeddings = vb.get((size1, size2), &format!("{p}.weight"))?;
+        Ok(Self::new(embeddings))
+    }
+
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
         Tensor::embedding(indexes, &self.embeddings)
     }
@@ -83,15 +146,23 @@
 
 struct Linear {
     weight: Tensor,
+    bias: Tensor,
 }
 
 impl Linear {
-    fn new(weight: Tensor) -> Self {
-        Self { weight }
+    fn new(weight: Tensor, bias: Tensor) -> Self {
+        Self { weight, bias }
+    }
+
+    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let weight = vb.get((size1, size2), &format!("{p}.weight"))?;
+        let bias = vb.get(size1, &format!("{p}.bias"))?;
+        Ok(Self::new(weight, bias))
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = x.matmul(&self.weight.t()?)?;
+        let x = x.broadcast_add(&self.bias)?;
         Ok(x)
     }
 }
@@ -112,12 +183,19 @@ impl Dropout {
 }
 
 struct LayerNorm {
-    scale: Tensor,
+    weight: Tensor,
+    bias: Tensor,
 }
 
 impl LayerNorm {
-    fn new(scale: Tensor) -> Self {
-        Self { scale }
+    fn new(weight: Tensor, bias: Tensor) -> Self {
+        Self { weight, bias }
+    }
+
+    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let weight = vb.get(size, &format!("{p}.weight"))?;
+        let bias = vb.get(size, &format!("{p}.bias"))?;
+        Ok(Self { weight, bias })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -125,9 +203,9 @@
         let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let size = self.scale.shape().r1()?;
-        let scale = self.scale.broadcast_as((seq_len, size))?;
-        let x = (scale * x_normed)?;
+        let x = x_normed
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
         Ok(x)
     }
 }
@@ -144,25 +222,34 @@ struct BertEmbeddings {
 }
 
 impl BertEmbeddings {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let word_embeddings =
-            Tensor::zeros((config.vocab_size, config.hidden_size), DTYPE, device)?;
-        let position_embeddings = Tensor::zeros(
-            (config.max_position_embeddings, config.hidden_size),
-            DTYPE,
-            device,
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let word_embeddings = Embedding::load(
+            config.vocab_size,
+            config.hidden_size,
+            &format!("{p}.word_embeddings"),
+            vb,
         )?;
-        let token_type_embeddings =
-            Tensor::zeros((config.type_vocab_size, config.hidden_size), DTYPE, device)?;
-        let layer_norm = Tensor::zeros((), DTYPE, device)?;
+        let position_embeddings = Embedding::load(
+            config.max_position_embeddings,
+            config.hidden_size,
+            &format!("{p}.position_embeddings"),
+            vb,
+        )?;
+        let token_type_embeddings = Embedding::load(
+            config.type_vocab_size,
+            config.hidden_size,
+            &format!("{p}.token_type_embeddings"),
+            vb,
+        )?;
+        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
         let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
-        let position_ids = Tensor::new(&position_ids[..], device)?.unsqueeze(0)?;
+        let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
         let token_type_ids = position_ids.zeros_like()?;
         Ok(Self {
-            word_embeddings: Embedding::new(word_embeddings),
-            position_embeddings: Some(Embedding::new(position_embeddings)),
-            token_type_embeddings: Embedding::new(token_type_embeddings),
-            layer_norm: LayerNorm::new(layer_norm),
+            word_embeddings,
+            position_embeddings: Some(position_embeddings),
+            token_type_embeddings,
+            layer_norm,
             dropout: Dropout::new(config.hidden_dropout_prob),
             position_ids,
             token_type_ids,
@@ -192,16 +279,14 @@ struct BertSelfAttention {
 }
 
 impl BertSelfAttention {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
         let attention_head_size = config.hidden_size / config.num_attention_heads;
         let all_head_size = config.num_attention_heads * attention_head_size;
         let dropout = Dropout::new(config.hidden_dropout_prob);
-        let query = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
-        let query = Linear::new(query);
-        let value = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
-        let value = Linear::new(value);
-        let key = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
-        let key = Linear::new(key);
+        let hidden_size = config.hidden_size;
+        let query = Linear::load(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
+        let value = Linear::load(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
+        let key = Linear::load(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
         Ok(Self {
             query,
             key,
@@ -248,11 +333,14 @@ struct BertSelfOutput {
 }
 
 impl BertSelfOutput {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let dense = Tensor::zeros((config.hidden_size, config.hidden_size), DTYPE, device)?;
-        let dense = Linear::new(dense);
-        let layer_norm = Tensor::zeros((), DTYPE, device)?;
-        let layer_norm = LayerNorm::new(layer_norm);
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let dense = Linear::load(
+            config.hidden_size,
+            config.hidden_size,
+            &format!("{p}.dense"),
+            vb,
+        )?;
+        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -275,9 +363,9 @@ struct BertAttention {
 }
 
 impl BertAttention {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let self_attention = BertSelfAttention::load(device, config)?;
-        let self_output = BertSelfOutput::load(device, config)?;
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let self_attention = BertSelfAttention::load(&format!("{p}.self_attention"), vb, config)?;
+        let self_output = BertSelfOutput::load(&format!("{p}.self_output"), vb, config)?;
         Ok(Self {
             self_attention,
             self_output,
@@ -298,13 +386,13 @@ struct BertIntermediate {
 }
 
 impl BertIntermediate {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let dense = Tensor::zeros(
-            (config.hidden_size, config.intermediate_size),
-            DTYPE,
-            device,
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let dense = Linear::load(
+            config.hidden_size,
+            config.intermediate_size,
+            &format!("{p}.dense"),
+            vb,
         )?;
-        let dense = Linear::new(dense);
         Ok(Self {
             dense,
             intermediate_act: config.hidden_act,
@@ -325,15 +413,14 @@ struct BertOutput {
 }
 
 impl BertOutput {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let dense = Tensor::zeros(
-            (config.intermediate_size, config.hidden_size),
-            DTYPE,
-            device,
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let dense = Linear::load(
+            config.intermediate_size,
+            config.hidden_size,
+            &format!("{p}.dense"),
+            vb,
         )?;
-        let dense = Linear::new(dense);
-        let layer_norm = Tensor::zeros((), DTYPE, device)?;
-        let layer_norm = LayerNorm::new(layer_norm);
+        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -357,10 +444,10 @@ struct BertLayer {
 }
 
 impl BertLayer {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let attention = BertAttention::load(device, config)?;
-        let intermediate = BertIntermediate::load(device, config)?;
-        let output = BertOutput::load(device, config)?;
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
+        let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
+        let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
         Ok(Self {
             attention,
             intermediate,
@@ -387,9 +474,12 @@ struct BertEncoder {
 }
 
 impl BertEncoder {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
+    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
-            .map(|_index| BertLayer::load(device, config))
+            .map(|index| {
+                let p = format!("{p}.{index}");
+                BertLayer::load(&p, vb, config)
+            })
             .collect::<Result<Vec<_>>>()?;
         Ok(BertEncoder { layers })
     }
@@ -411,9 +501,9 @@ struct BertModel {
 }
 
 impl BertModel {
-    fn load(device: &Device, config: &Config) -> Result<Self> {
-        let embeddings = BertEmbeddings::load(device, config)?;
-        let encoder = BertEncoder::load(device, config)?;
+    fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
+        let embeddings = BertEmbeddings::load("embeddings", vb, config)?;
+        let encoder = BertEncoder::load("encoder", vb, config)?;
         Ok(Self {
             embeddings,
             encoder,
@@ -428,5 +518,9 @@ impl BertModel {
 }
 
 fn main() -> R<()> {
+    let device = Device::Cpu;
+    let vb = VarBuilder::zeros(DTYPE, device);
+    let config = Config::default();
+    let _model = BertModel::load(&vb, &config)?;
     Ok(())
 }
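
Usage sketch for the new VarBuilder with real weights. main() in this patch only exercises VarBuilder::zeros, so every tensor is a zero placeholder; the sketch below shows how from_safetensors could be wired up instead. The path model.safetensors and the helper name load_from_disk are assumptions for illustration; SafeTensors::deserialize is the safetensors-crate constructor that candle re-exports.

    // Hypothetical wiring, not part of the patch: build the VarBuilder from a
    // serialized safetensors file rather than zero-initialized placeholders.
    fn load_from_disk() -> R<()> {
        let device = Device::Cpu;
        // Assumed local path for this sketch; the buffer must outlive the
        // SafeTensors view, which borrows from it.
        let buffer = std::fs::read("model.safetensors")?;
        let tensors = SafeTensors::deserialize(&buffer)?;
        let vb = VarBuilder::from_safetensors(vec![tensors], DTYPE, device);
        let config = Config::default();
        // With real weights behind the builder, a wrong shape now surfaces as
        // Error::UnexpectedShape and a missing name as a lookup error, instead
        // of being papered over by Tensor::zeros.
        let _model = BertModel::load(&vb, &config)?;
        Ok(())
    }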