mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
VarBuilder path creation (#131)
* Use a struct for the safetensor+routing.
* Group the path and the var-builder together.
* Fix for the empty path case.
This commit is contained in:
@@ -169,7 +169,7 @@ fn main() -> Result<()> {
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
|
||||
let config = Config::falcon7b();
|
||||
config.validate()?;
|
||||
let model = Falcon::load(&vb, config)?;
|
||||
let model = Falcon::load(vb, config)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
|
||||
|
@@ -4,27 +4,21 @@ use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, &format!("{p}.bias"))?)
|
||||
Some(vb.get(size2, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (
|
||||
vb.get(size, &format!("{p}.weight")),
|
||||
vb.get(size, &format!("{p}.bias")),
|
||||
) {
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let (Ok(weight), Ok(bias)) = (
|
||||
vb.get(size, &format!("{p}.gamma")),
|
||||
vb.get(size, &format!("{p}.beta")),
|
||||
) {
|
||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||
(weight, bias)
|
||||
} else {
|
||||
return Err(err.into());
|
||||
@@ -50,8 +44,8 @@ impl Dropout {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
@@ -164,14 +158,14 @@ struct FalconRotaryEmbedding {
|
||||
}
|
||||
|
||||
impl FalconRotaryEmbedding {
|
||||
fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(device: &Device, cfg: &Config) -> Result<Self> {
|
||||
let head_dim = cfg.head_dim();
|
||||
let inv_freq: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
Ok(Self {
|
||||
inv_freq: Tensor::new(inv_freq.as_slice(), &vb.device)?,
|
||||
inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
|
||||
cache: None,
|
||||
})
|
||||
}
|
||||
@@ -237,9 +231,9 @@ struct FalconAttention {
|
||||
}
|
||||
|
||||
impl FalconAttention {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let maybe_rotary = if cfg.rotary() {
|
||||
let rotary = FalconRotaryEmbedding::load(vb, cfg)?;
|
||||
let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
|
||||
Some(rotary)
|
||||
} else {
|
||||
None
|
||||
@@ -251,20 +245,8 @@ impl FalconAttention {
|
||||
} else {
|
||||
3 * hidden_size
|
||||
};
|
||||
let query_key_value = linear(
|
||||
hidden_size,
|
||||
qkv_out_dim,
|
||||
cfg.bias,
|
||||
&format!("{p}.query_key_value"),
|
||||
vb,
|
||||
)?;
|
||||
let dense = linear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
cfg.bias,
|
||||
&format!("{p}.dense"),
|
||||
vb,
|
||||
)?;
|
||||
let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
|
||||
let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
|
||||
Ok(Self {
|
||||
query_key_value,
|
||||
dense,
|
||||
@@ -367,11 +349,11 @@ struct FalconMlp {
|
||||
}
|
||||
|
||||
impl FalconMlp {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let b = cfg.bias;
|
||||
let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
|
||||
let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
|
||||
let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
|
||||
let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
|
||||
let dropout = Dropout::new(cfg.hidden_dropout);
|
||||
Ok(Self {
|
||||
dense_h_to_4h,
|
||||
@@ -397,23 +379,21 @@ struct FalconDecoderLayer {
|
||||
}
|
||||
|
||||
impl FalconDecoderLayer {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
|
||||
let inp_layernorm = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
&format!("{p}.input_layernorm"),
|
||||
vb,
|
||||
vb.pp("input_layernorm"),
|
||||
)?;
|
||||
let self_attention = FalconAttention::load(&format!("{p}.self_attention"), vb, cfg)?;
|
||||
let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
|
||||
let post_attention_layernorm = if cfg.parallel_attn {
|
||||
None
|
||||
} else {
|
||||
let ln = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
&format!("{p}.post_attention_layernorm"),
|
||||
vb,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Some(ln)
|
||||
};
|
||||
@@ -480,23 +460,21 @@ impl Falcon {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
|
||||
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
|
||||
let word_embeddings = embedding(
|
||||
cfg.vocab_size,
|
||||
cfg.hidden_size,
|
||||
"transformer.word_embeddings",
|
||||
vb,
|
||||
vb.pp("transformer.word_embeddings"),
|
||||
)?;
|
||||
let blocks = (0..cfg.num_hidden_layers)
|
||||
.map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
|
||||
.map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
"transformer.ln_f",
|
||||
vb,
|
||||
vb.pp("transformer.ln_f"),
|
||||
)?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
blocks,
|
||||
|
Reference in New Issue
Block a user