VarBuilder path creation (#131)

* Use a struct for the safetensor+routing.

* Group the path and the var-builder together.

* Fix for the empty path case.
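
With this change the `VarBuilder` carries its own path prefix and sub-prefixes are created with `vb.pp(...)`, so `load` functions no longer take a separate `p: &str`. Below is a minimal sketch of the new calling convention, using only the `vb.pp` / `vb.get` / `vb.device` API that appears in the diff; the `candle_nn::VarBuilder` import path and the `Mlp` / `fc1` / `fc2` names are illustrative assumptions, not part of this change.

```rust
// Hypothetical block written in the post-#131 style: the VarBuilder owns the
// path, so callers hand each sub-module a prefixed builder via `vb.pp(...)`.
use candle::{Result, Tensor};
use candle_nn::VarBuilder; // assumed import path; may differ at this commit

struct Mlp {
    w1: Tensor,
    w2: Tensor,
}

impl Mlp {
    fn load(dim: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
        // `vb.pp("fc1")` yields a builder rooted at "<prefix>.fc1", so the
        // leaf lookup below resolves to "<prefix>.fc1.weight" in the
        // safetensors file.
        let w1 = vb.pp("fc1").get((hidden, dim), "weight")?;
        let w2 = vb.pp("fc2").get((dim, hidden), "weight")?;
        Ok(Self { w1, w2 })
    }
}
```

A parent module would then call `Mlp::load(dim, hidden, vb.pp("mlp"))` instead of threading `&format!("{p}.mlp")` and `&vb` through separately, as the diff below does for the BERT layers.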
Laurent Mazare
2023-07-10 22:37:34 +01:00
committed by GitHub
parent 1aa7fbbc33
commit b46c28a2ac
5 changed files with 196 additions and 200 deletions


@@ -109,14 +109,14 @@ impl Config {
     }
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
-fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-    let bias = vb.get(size2, &format!("{p}.bias"))?;
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), "weight")?;
+    let bias = vb.get(size2, "bias")?;
     Ok(Linear::new(weight, Some(bias)))
 }
@@ -135,17 +135,11 @@ impl Dropout {
     }
 }
 
-fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (
-        vb.get(size, &format!("{p}.weight")),
-        vb.get(size, &format!("{p}.bias")),
-    ) {
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
         (Ok(weight), Ok(bias)) => (weight, bias),
         (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (
-                vb.get(size, &format!("{p}.gamma")),
-                vb.get(size, &format!("{p}.beta")),
-            ) {
+            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
                 (weight, bias)
             } else {
                 return Err(err.into());
@@ -167,33 +161,29 @@ struct BertEmbeddings {
 }
 
 impl BertEmbeddings {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let word_embeddings = embedding(
             config.vocab_size,
             config.hidden_size,
-            &format!("{p}.word_embeddings"),
-            vb,
+            vb.pp("word_embeddings"),
         )?;
         let position_embeddings = embedding(
             config.max_position_embeddings,
             config.hidden_size,
-            &format!("{p}.position_embeddings"),
-            vb,
+            vb.pp("position_embeddings"),
         )?;
         let token_type_embeddings = embedding(
             config.type_vocab_size,
             config.hidden_size,
-            &format!("{p}.token_type_embeddings"),
-            vb,
+            vb.pp("token_type_embeddings"),
         )?;
         let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
-            &format!("{p}.LayerNorm"),
-            vb,
+            vb.pp("LayerNorm"),
         )?;
         let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
-        let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
+        let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?;
         let token_type_ids = position_ids.zeros_like()?;
         Ok(Self {
             word_embeddings,
@@ -233,14 +223,14 @@ struct BertSelfAttention {
 }
 
 impl BertSelfAttention {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let attention_head_size = config.hidden_size / config.num_attention_heads;
         let all_head_size = config.num_attention_heads * attention_head_size;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         let hidden_size = config.hidden_size;
-        let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
-        let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
-        let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
+        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
         Ok(Self {
             query,
             key,
@@ -289,18 +279,12 @@ struct BertSelfOutput {
 }
 
 impl BertSelfOutput {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(
-            config.hidden_size,
-            config.hidden_size,
-            &format!("{p}.dense"),
-            vb,
-        )?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
         let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
-            &format!("{p}.LayerNorm"),
-            vb,
+            vb.pp("LayerNorm"),
         )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
@@ -324,9 +308,9 @@ struct BertAttention {
 }
 
 impl BertAttention {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
-        let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
+        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
         Ok(Self {
             self_attention,
             self_output,
@@ -347,13 +331,8 @@ struct BertIntermediate {
 }
 
 impl BertIntermediate {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(
-            config.hidden_size,
-            config.intermediate_size,
-            &format!("{p}.dense"),
-            vb,
-        )?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
         Ok(Self {
             dense,
             intermediate_act: config.hidden_act,
@@ -375,18 +354,12 @@ struct BertOutput {
 }
 
 impl BertOutput {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(
-            config.intermediate_size,
-            config.hidden_size,
-            &format!("{p}.dense"),
-            vb,
-        )?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
         let layer_norm = layer_norm(
             config.hidden_size,
             config.layer_norm_eps,
-            &format!("{p}.LayerNorm"),
-            vb,
+            vb.pp("LayerNorm"),
         )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
@@ -411,10 +384,10 @@ struct BertLayer {
 }
 
 impl BertLayer {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
-        let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
-        let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let attention = BertAttention::load(vb.pp("attention"), config)?;
+        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
+        let output = BertOutput::load(vb.pp("output"), config)?;
         Ok(Self {
             attention,
             intermediate,
@@ -441,12 +414,9 @@ struct BertEncoder {
 }
 
 impl BertEncoder {
-    fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
-            .map(|index| {
-                let p = format!("{p}.layer.{index}");
-                BertLayer::load(&p, vb, config)
-            })
+            .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
             .collect::<Result<Vec<_>>>()?;
         Ok(BertEncoder { layers })
     }
@@ -469,17 +439,17 @@ struct BertModel {
 }
 
 impl BertModel {
-    fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let (embeddings, encoder) = match (
-            BertEmbeddings::load("embeddings", vb, config),
-            BertEncoder::load("encoder", vb, config),
+            BertEmbeddings::load(vb.pp("embeddings"), config),
+            BertEncoder::load(vb.pp("encoder"), config),
         ) {
             (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
             (Err(err), _) | (_, Err(err)) => {
                 if let Some(model_type) = &config.model_type {
                     if let (Ok(embeddings), Ok(encoder)) = (
-                        BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
-                        BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
+                        BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
+                        BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
                     ) {
                         (embeddings, encoder)
                     } else {
@@ -493,7 +463,7 @@ impl BertModel {
         Ok(Self {
             embeddings,
             encoder,
-            device: vb.device.clone(),
+            device: vb.device().clone(),
         })
     }
@@ -576,7 +546,7 @@ impl Args {
         let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
         let weights = weights.deserialize()?;
         let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
-        let model = BertModel::load(&vb, &config)?;
+        let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
 }