From b46c28a2ac2a88387590c65a2efef028f010b29e Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 10 Jul 2023 22:37:34 +0100
Subject: [PATCH] VarBuilder path creation (#131)

* Use a struct for the safetensor+routing.

* Group the path and the var-builder together.

* Fix for the empty path case.
---
 candle-examples/examples/bert/main.rs     | 114 ++++++++----------
 candle-examples/examples/falcon/main.rs   |   2 +-
 candle-examples/examples/falcon/model.rs  |  76 +++++----------
 candle-examples/examples/whisper/model.rs | 101 ++++++++-----------
 candle-nn/src/var_builder.rs              | 103 +++++++++++++++----
 5 files changed, 196 insertions(+), 200 deletions(-)

diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 3871c752..d0d600ee 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -109,14 +109,14 @@ impl Config {
     }
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
-fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-    let bias = vb.get(size2, &format!("{p}.bias"))?;
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), "weight")?;
+    let bias = vb.get(size2, "bias")?;
     Ok(Linear::new(weight, Some(bias)))
 }
 
@@ -135,17 +135,11 @@ impl Dropout {
     }
 }
 
-fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (
-        vb.get(size, &format!("{p}.weight")),
-        vb.get(size, &format!("{p}.bias")),
-    ) {
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
         (Ok(weight), Ok(bias)) => (weight, bias),
         (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (
-                vb.get(size, &format!("{p}.gamma")),
-                vb.get(size, &format!("{p}.beta")),
-            ) {
+            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
                 (weight, bias)
             } else {
                 return Err(err.into());
@@ -167,33 +161,29 @@ struct BertEmbeddings {
 }
 
 impl BertEmbeddings {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let word_embeddings = embedding(
             config.vocab_size,
             config.hidden_size,
-            &format!("{p}.word_embeddings"),
-            vb,
+            vb.pp("word_embeddings"),
         )?;
         let position_embeddings = embedding(
             config.max_position_embeddings,
             config.hidden_size,
-            &format!("{p}.position_embeddings"),
-            vb,
+            vb.pp("position_embeddings"),
         )?;
         let token_type_embeddings = embedding(
             config.type_vocab_size,
             config.hidden_size,
-            &format!("{p}.token_type_embeddings"),
-            vb,
+            vb.pp("token_type_embeddings"),
         )?;
         let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
-            &format!("{p}.LayerNorm"),
-            vb,
+            vb.pp("LayerNorm"),
         )?;
         let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
-        let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
+        let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?;
         let token_type_ids = position_ids.zeros_like()?;
         Ok(Self {
             word_embeddings,
@@ -233,14 +223,14 @@ struct BertSelfAttention {
 }
 
 impl BertSelfAttention {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let attention_head_size = config.hidden_size / config.num_attention_heads;
         let all_head_size = config.num_attention_heads * attention_head_size;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         let hidden_size = config.hidden_size;
-        let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
-        let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
-        let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
+        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
         Ok(Self {
             query,
             key,
@@ -289,18 +279,12 @@ struct BertSelfOutput {
 }
 
 impl BertSelfOutput {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(
-            config.hidden_size,
-            config.hidden_size,
-            &format!("{p}.dense"),
-            vb,
-        )?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
         let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
-            &format!("{p}.LayerNorm"),
-            vb,
+            vb.pp("LayerNorm"),
         )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
@@ -324,9 +308,9 @@ struct BertAttention {
 }
 
 impl BertAttention {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
-        let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
+        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
         Ok(Self {
             self_attention,
             self_output,
@@ -347,13 +331,8 @@ struct BertIntermediate {
 }
 
 impl BertIntermediate {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(
-            config.hidden_size,
-            config.intermediate_size,
-            &format!("{p}.dense"),
-            vb,
-        )?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
         Ok(Self {
             dense,
             intermediate_act: config.hidden_act,
@@ -375,18 +354,12 @@ struct BertOutput {
 }
 
 impl BertOutput {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(
-            config.intermediate_size,
-            config.hidden_size,
-            &format!("{p}.dense"),
-            vb,
-        )?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
         let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
-            &format!("{p}.LayerNorm"),
-            vb,
+            vb.pp("LayerNorm"),
         )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
@@ -411,10 +384,10 @@ struct BertLayer {
 }
 
 impl BertLayer {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
-        let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
-        let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let attention = BertAttention::load(vb.pp("attention"), config)?;
+        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
+        let output = BertOutput::load(vb.pp("output"), config)?;
         Ok(Self {
             attention,
             intermediate,
@@ -441,12 +414,9 @@ struct BertEncoder {
 }
 
 impl BertEncoder {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
-            .map(|index| {
-                let p = format!("{p}.layer.{index}");
-                BertLayer::load(&p, vb, config)
-            })
+            .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
             .collect::<Result<Vec<_>>>()?;
         Ok(BertEncoder { layers })
     }
@@ -469,17 +439,17 @@ struct BertModel {
 }
 
 impl BertModel {
-    fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let (embeddings, encoder) = match (
-            BertEmbeddings::load("embeddings", vb, config),
-            BertEncoder::load("encoder", vb, config),
+            BertEmbeddings::load(vb.pp("embeddings"), config),
+            BertEncoder::load(vb.pp("encoder"), config),
         ) {
             (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
             (Err(err), _) | (_, Err(err)) => {
                 if let Some(model_type) = &config.model_type {
                     if let (Ok(embeddings), Ok(encoder)) = (
-                        BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
-                        BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
+                        BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
+                        BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
                     ) {
                         (embeddings, encoder)
                     } else {
@@ -493,7 +463,7 @@ impl BertModel {
         Ok(Self {
             embeddings,
             encoder,
-            device: vb.device.clone(),
+            device: vb.device().clone(),
         })
     }
 
@@ -576,7 +546,7 @@ impl Args {
         let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
         let weights = weights.deserialize()?;
         let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
-        let model = BertModel::load(&vb, &config)?;
+        let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
 }
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 84de6480..a59a0349 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -169,7 +169,7 @@ fn main() -> Result<()> {
     let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
     let config = Config::falcon7b();
     config.validate()?;
-    let model = Falcon::load(&vb, config)?;
+    let model = Falcon::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index 1300e7cb..631ff280 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -4,27 +4,21 @@ use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
 
 const MAX_SEQ_LEN: usize = 5000;
 
-fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), "weight")?;
     let bias = if bias {
-        Some(vb.get(size2, &format!("{p}.bias"))?)
+        Some(vb.get(size2, "bias")?)
     } else {
         None
     };
     Ok(Linear::new(weight, bias))
 }
 
-fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (
-        vb.get(size, &format!("{p}.weight")),
-        vb.get(size, &format!("{p}.bias")),
-    ) {
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
         (Ok(weight), Ok(bias)) => (weight, bias),
         (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (
-                vb.get(size, &format!("{p}.gamma")),
-                vb.get(size, &format!("{p}.beta")),
-            ) {
+            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
                 (weight, bias)
             } else {
                 return Err(err.into());
@@ -50,8 +44,8 @@ impl Dropout {
     }
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
@@ -164,14 +158,14 @@ struct FalconRotaryEmbedding {
 }
 
 impl FalconRotaryEmbedding {
-    fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(device: &Device, cfg: &Config) -> Result<Self> {
         let head_dim = cfg.head_dim();
         let inv_freq: Vec<_> = (0..head_dim)
             .step_by(2)
             .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
             .collect();
         Ok(Self {
-            inv_freq: Tensor::new(inv_freq.as_slice(), &vb.device)?,
+            inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
             cache: None,
         })
     }
@@ -237,9 +231,9 @@ struct FalconAttention {
 }
 
 impl FalconAttention {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let maybe_rotary = if cfg.rotary() {
-            let rotary = FalconRotaryEmbedding::load(vb, cfg)?;
+            let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
             Some(rotary)
         } else {
            None
        };
@@ -251,20 +245,8 @@ impl FalconAttention {
         } else {
             3 * hidden_size
         };
-        let query_key_value = linear(
-            hidden_size,
-            qkv_out_dim,
-            cfg.bias,
-            &format!("{p}.query_key_value"),
-            vb,
-        )?;
-        let dense = linear(
-            hidden_size,
-            hidden_size,
-            cfg.bias,
-            &format!("{p}.dense"),
-            vb,
-        )?;
+        let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
+        let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
         Ok(Self {
             query_key_value,
             dense,
@@ -367,11 +349,11 @@ struct FalconMlp {
 }
 
 impl FalconMlp {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let h = cfg.hidden_size;
         let b = cfg.bias;
-        let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
-        let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
+        let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
+        let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
         let dropout = Dropout::new(cfg.hidden_dropout);
         Ok(Self {
             dense_h_to_4h,
@@ -397,23 +379,21 @@ struct FalconDecoderLayer {
 }
 
 impl FalconDecoderLayer {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
         let inp_layernorm = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
-            &format!("{p}.input_layernorm"),
-            vb,
+            vb.pp("input_layernorm"),
         )?;
-        let self_attention =
-            FalconAttention::load(&format!("{p}.self_attention"), vb, cfg)?;
+        let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
         let post_attention_layernorm = if cfg.parallel_attn {
             None
         } else {
             let ln = layer_norm(
                 cfg.hidden_size,
                 cfg.layer_norm_epsilon,
-                &format!("{p}.post_attention_layernorm"),
-                vb,
+                vb.pp("post_attention_layernorm"),
             )?;
             Some(ln)
         };
@@ -480,23 +460,21 @@ impl Falcon {
         &self.config
     }
 
-    pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
         let word_embeddings = embedding(
             cfg.vocab_size,
             cfg.hidden_size,
-            "transformer.word_embeddings",
-            vb,
+            vb.pp("transformer.word_embeddings"),
         )?;
         let blocks = (0..cfg.num_hidden_layers)
-            .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
+            .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
             .collect::<Result<Vec<_>>>()?;
         let ln_f = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
-            "transformer.ln_f",
-            vb,
+            vb.pp("transformer.ln_f"),
         )?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
+        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
         Ok(Self {
             word_embeddings,
             blocks,
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index d653d0c7..ece8b2d8 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -38,19 +38,19 @@ impl Config {
     }
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
-fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-    let bias = vb.get(size2, &format!("{p}.bias"))?;
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), "weight")?;
+    let bias = vb.get(size2, "bias")?;
     Ok(Linear::new(weight, Some(bias)))
 }
 
-fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), "weight")?;
     Ok(Linear::new(weight, None))
 }
 
@@ -59,14 +59,10 @@ fn conv1d(
     out_channels: usize,
     kernel_size: usize,
     config: Conv1dConfig,
-    p: &str,
-    vb: &VarBuilder,
+    vb: VarBuilder,
 ) -> Result<Conv1d> {
-    let weight = vb.get(
-        (out_channels, in_channels, kernel_size),
-        &format!("{p}.weight"),
-    )?;
-    let bias = vb.get(out_channels, &format!("{p}.bias"))?;
+    let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
+    let bias = vb.get(out_channels, "bias")?;
     Ok(Conv1d::new(weight, Some(bias), config))
 }
 
@@ -75,13 +71,9 @@ fn conv1d_no_bias(
     out_channels: usize,
     kernel_size: usize,
     config: Conv1dConfig,
-    p: &str,
-    vb: &VarBuilder,
+    vb: VarBuilder,
 ) -> Result<Conv1d> {
-    let weight = vb.get(
-        (out_channels, in_channels, kernel_size),
-        &format!("{p}.weight"),
-    )?;
+    let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
     Ok(Conv1d::new(weight, None, config))
 }
 
@@ -100,9 +92,9 @@ impl Dropout {
     }
 }
 
-fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
&format!("{p}.weight"))?; - let bias = vb.get(size, &format!("{p}.bias"))?; +fn layer_norm(size: usize, vb: VarBuilder) -> Result { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; Ok(LayerNorm::new(weight, bias, 1e-5)) } @@ -116,11 +108,11 @@ struct MultiHeadAttention { } impl MultiHeadAttention { - fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result { - let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?; - let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?; - let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?; - let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?; + fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result { + let query = linear(n_state, n_state, vb.pp("q_proj"))?; + let value = linear(n_state, n_state, vb.pp("v_proj"))?; + let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; + let out = linear(n_state, n_state, vb.pp("out_proj"))?; Ok(Self { query, key, @@ -179,21 +171,20 @@ struct ResidualAttentionBlock { } impl ResidualAttentionBlock { - fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result { - let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?; - let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?; + fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result { + let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; + let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; let cross_attn = if ca { - let cross_attn = - MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?; - let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?; + let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; + let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; Some((cross_attn, cross_attn_ln)) } else { None }; let n_mlp = n_state * 4; - let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?; - let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?; - let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?; + let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; + let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; + let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; Ok(Self { attn, attn_ln, @@ -245,7 +236,7 @@ pub struct AudioEncoder { } impl AudioEncoder { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + fn load(vb: VarBuilder, cfg: &Config) -> Result { let n_state = cfg.d_model; let n_head = cfg.encoder_attention_heads; let n_ctx = cfg.max_source_positions; @@ -257,22 +248,15 @@ impl AudioEncoder { padding: 1, stride: 2, }; - let conv1 = conv1d( - cfg.num_mel_bins, - n_state, - 3, - cfg1, - &format!("{p}.conv1"), - vb, - )?; - let conv2 = conv1d(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; - let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?; + let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; + let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; let blocks = (0..cfg.encoder_layers) .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb) + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) }) .collect::>>()?; - let ln_post 
-        let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
+        let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
         Ok(Self {
             conv1,
             conv2,
@@ -306,23 +290,22 @@ pub struct TextDecoder {
 }
 
 impl TextDecoder {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let n_state = cfg.d_model;
         let n_head = cfg.decoder_attention_heads;
         let n_ctx = cfg.max_target_positions;
-        let token_embedding = embedding(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
-        let positional_embedding =
-            vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
+        let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
+        let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
         let blocks = (0..cfg.decoder_layers)
             .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
+                ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
+        let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
         let mask: Vec<_> = (0..n_ctx)
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
-        let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), &vb.device)?;
+        let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
 
         Ok(Self {
             token_embedding,
@@ -361,8 +344,8 @@ pub struct Whisper {
 
 impl Whisper {
     pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
-        let encoder = AudioEncoder::load("model.encoder", vb, &config)?;
-        let decoder = TextDecoder::load("model.decoder", vb, &config)?;
+        let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
+        let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
         Ok(Self {
             encoder,
             decoder,
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 203640b0..d71b5822 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,53 +1,118 @@
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use std::collections::HashMap;
+use std::sync::Arc;
 
-pub struct VarBuilder<'a> {
-    safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
+struct SafeTensorWithRouting<'a> {
+    routing: HashMap<String, usize>,
+    safetensors: Vec<SafeTensors<'a>>,
+}
+
+struct TensorData<'a> {
+    // TODO: Make this part generic, probably via some Box to avoid too much generics.
+    safetensors: Option<SafeTensorWithRouting<'a>>,
     pub dtype: DType,
     pub device: Device,
 }
 
-impl<'a> VarBuilder<'a> {
-    pub fn from_safetensors(
-        safetensors: Vec<SafeTensors<'a>>,
-        dtype: DType,
-        device: &Device,
-    ) -> Self {
+impl<'a> TensorData<'a> {
+    fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self {
         let mut routing = HashMap::new();
         for (index, sf) in safetensors.iter().enumerate() {
             for k in sf.names() {
                 routing.insert(k.to_string(), index);
             }
         }
+        let safetensors = SafeTensorWithRouting {
+            routing,
+            safetensors,
+        };
         Self {
-            safetensors: Some((routing, safetensors)),
+            safetensors: Some(safetensors),
             device: device.clone(),
             dtype,
         }
     }
 
-    pub fn zeros(dtype: DType, device: Device) -> Self {
+    fn zeros(dtype: DType, device: &Device) -> Self {
         Self {
             safetensors: None,
-            device,
+            device: device.clone(),
             dtype,
         }
     }
+}
 
+#[derive(Clone)]
+pub struct VarBuilder<'a> {
+    data: Arc<TensorData<'a>>,
+    path: Vec<String>,
+}
+
+impl<'a> VarBuilder<'a> {
+    /// Create a `VarBuilder` accessing data from the safetensors storage. The initial path is
+    /// set to the root path and sub-paths can be created via the `push_prefix` method.
+    pub fn from_safetensors(st: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self {
+        let data = TensorData::from_safetensors(st, dtype, device);
+        Self {
+            data: Arc::new(data),
+            path: vec![],
+        }
+    }
+
+    pub fn zeros(dtype: DType, device: &Device) -> Self {
+        let data = TensorData::zeros(dtype, device);
+        Self {
+            data: Arc::new(data),
+            path: vec![],
+        }
+    }
+
+    pub fn push_prefix(&self, s: &str) -> Self {
+        let mut path = self.path.clone();
+        path.push(s.to_string());
+        Self {
+            data: self.data.clone(),
+            path,
+        }
+    }
+
+    /// Short alias for `push_prefix`.
+    pub fn pp(&self, s: &str) -> Self {
+        self.push_prefix(s)
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.data.device
+    }
+
+    pub fn dtype(&self) -> DType {
+        self.data.dtype
+    }
+}
+
+impl<'a> VarBuilder<'a> {
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
+        let data = self.data.as_ref();
         let s: Shape = s.into();
-        match &self.safetensors {
-            None => Tensor::zeros(s, self.dtype, &self.device),
-            Some((routing, safetensors)) => {
+        match &self.data.safetensors {
+            None => Tensor::zeros(s, data.dtype, &data.device),
+            Some(SafeTensorWithRouting {
+                routing,
+                safetensors,
+            }) => {
+                let path = if self.path.is_empty() {
+                    tensor_name.to_string()
+                } else {
+                    [&self.path.join("."), tensor_name].join(".")
+                };
                 // Unwrap or 0 just to let the proper error flow.
-                let index = routing.get(tensor_name).unwrap_or(&0);
+                let index = routing.get(&path).unwrap_or(&0);
                 let tensor = safetensors[*index]
-                    .tensor(tensor_name, &self.device)?
-                    .to_dtype(self.dtype)?;
+                    .tensor(&path, &data.device)?
+                    .to_dtype(data.dtype)?;
                 if *tensor.shape() != s {
-                    let msg = format!("shape mismatch for {tensor_name}");
                     Err(candle::Error::UnexpectedShape {
-                        msg,
+                        msg: format!("shape mismatch for {path}"),
                         expected: s,
                         got: tensor.shape().clone(),
                     })?
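
Usage sketch (not part of the patch): after this change a `VarBuilder` carries its own dot-separated prefix behind an `Arc`, so model code threads `vb.pp("...")` down the module tree instead of formatting `{p}.name` strings by hand. A minimal illustration against the API introduced above; the `encoder` and `layers.0` names are invented for the example, and `VarBuilder::zeros` is used so that no safetensors file is needed:

    use candle::{DType, Device};
    use candle_nn::VarBuilder;

    fn main() -> candle::Result<()> {
        // A zero-initialized builder: every `get` returns zeros of the
        // requested shape, which is enough to exercise the path logic.
        let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);

        // `pp` (alias for `push_prefix`) clones the builder cheaply, since
        // the tensor storage is shared through the `Arc`, and appends one
        // path segment per call.
        let vb_block = vb.pp("encoder").pp("layers.0");

        // Inside `get`, the accumulated segments and the tensor name are
        // joined with '.', so this resolves "encoder.layers.0.weight".
        let weight = vb_block.get((768, 768), "weight")?;
        println!("{:?}", weight.shape());
        Ok(())
    }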