mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
VarBuilder path creation (#131)
* Use a struct for the safetensor+routing.
* Group the path and the var-builder together.
* Fix for the empty path case.
This commit is contained in:
@ -109,14 +109,14 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
let bias = vb.get(size2, &format!("{p}.bias"))?;
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = vb.get(size2, "bias")?;
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
}
|
||||
|
||||
@ -135,17 +135,11 @@ impl Dropout {
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (
|
||||
vb.get(size, &format!("{p}.weight")),
|
||||
vb.get(size, &format!("{p}.bias")),
|
||||
) {
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let (Ok(weight), Ok(bias)) = (
|
||||
vb.get(size, &format!("{p}.gamma")),
|
||||
vb.get(size, &format!("{p}.beta")),
|
||||
) {
|
||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||
(weight, bias)
|
||||
} else {
|
||||
return Err(err.into());
|
||||
@ -167,33 +161,29 @@ struct BertEmbeddings {
|
||||
}
|
||||
|
||||
impl BertEmbeddings {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let word_embeddings = embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
&format!("{p}.word_embeddings"),
|
||||
vb,
|
||||
vb.pp("word_embeddings"),
|
||||
)?;
|
||||
let position_embeddings = embedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
&format!("{p}.position_embeddings"),
|
||||
vb,
|
||||
vb.pp("position_embeddings"),
|
||||
)?;
|
||||
let token_type_embeddings = embedding(
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
&format!("{p}.token_type_embeddings"),
|
||||
vb,
|
||||
vb.pp("token_type_embeddings"),
|
||||
)?;
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
&format!("{p}.LayerNorm"),
|
||||
vb,
|
||||
vb.pp("LayerNorm"),
|
||||
)?;
|
||||
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
|
||||
let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
|
||||
let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?;
|
||||
let token_type_ids = position_ids.zeros_like()?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
@ -233,14 +223,14 @@ struct BertSelfAttention {
|
||||
}
|
||||
|
||||
impl BertSelfAttention {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||
let all_head_size = config.num_attention_heads * attention_head_size;
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let hidden_size = config.hidden_size;
|
||||
let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
|
||||
let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
|
||||
let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
|
||||
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
|
||||
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
|
||||
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
@ -289,18 +279,12 @@ struct BertSelfOutput {
|
||||
}
|
||||
|
||||
impl BertSelfOutput {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
&format!("{p}.dense"),
|
||||
vb,
|
||||
)?;
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
&format!("{p}.LayerNorm"),
|
||||
vb,
|
||||
vb.pp("LayerNorm"),
|
||||
)?;
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
Ok(Self {
|
||||
@ -324,9 +308,9 @@ struct BertAttention {
|
||||
}
|
||||
|
||||
impl BertAttention {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
|
||||
let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
|
||||
let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
|
||||
Ok(Self {
|
||||
self_attention,
|
||||
self_output,
|
||||
@ -347,13 +331,8 @@ struct BertIntermediate {
|
||||
}
|
||||
|
||||
impl BertIntermediate {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
&format!("{p}.dense"),
|
||||
vb,
|
||||
)?;
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
|
||||
Ok(Self {
|
||||
dense,
|
||||
intermediate_act: config.hidden_act,
|
||||
@ -375,18 +354,12 @@ struct BertOutput {
|
||||
}
|
||||
|
||||
impl BertOutput {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
&format!("{p}.dense"),
|
||||
vb,
|
||||
)?;
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
&format!("{p}.LayerNorm"),
|
||||
vb,
|
||||
vb.pp("LayerNorm"),
|
||||
)?;
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
Ok(Self {
|
||||
@ -411,10 +384,10 @@ struct BertLayer {
|
||||
}
|
||||
|
||||
impl BertLayer {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
|
||||
let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
|
||||
let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let attention = BertAttention::load(vb.pp("attention"), config)?;
|
||||
let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
|
||||
let output = BertOutput::load(vb.pp("output"), config)?;
|
||||
Ok(Self {
|
||||
attention,
|
||||
intermediate,
|
||||
@ -441,12 +414,9 @@ struct BertEncoder {
|
||||
}
|
||||
|
||||
impl BertEncoder {
|
||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let layers = (0..config.num_hidden_layers)
|
||||
.map(|index| {
|
||||
let p = format!("{p}.layer.{index}");
|
||||
BertLayer::load(&p, vb, config)
|
||||
})
|
||||
.map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(BertEncoder { layers })
|
||||
}
|
||||
@ -469,17 +439,17 @@ struct BertModel {
|
||||
}
|
||||
|
||||
impl BertModel {
|
||||
fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let (embeddings, encoder) = match (
|
||||
BertEmbeddings::load("embeddings", vb, config),
|
||||
BertEncoder::load("encoder", vb, config),
|
||||
BertEmbeddings::load(vb.pp("embeddings"), config),
|
||||
BertEncoder::load(vb.pp("encoder"), config),
|
||||
) {
|
||||
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let Some(model_type) = &config.model_type {
|
||||
if let (Ok(embeddings), Ok(encoder)) = (
|
||||
BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
|
||||
BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
|
||||
BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
||||
BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
|
||||
) {
|
||||
(embeddings, encoder)
|
||||
} else {
|
||||
@ -493,7 +463,7 @@ impl BertModel {
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
device: vb.device.clone(),
|
||||
device: vb.device().clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@ -576,7 +546,7 @@ impl Args {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let model = BertModel::load(&vb, &config)?;
|
||||
let model = BertModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user