mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
VarBuilder path creation (#131)
* Use a struct for the safetensor+routing. * Group the path and the var-builder together. * Fix for the empty path case.
This commit is contained in:
@ -109,14 +109,14 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
|
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||||
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||||
Ok(Embedding::new(embeddings, hidden_size))
|
Ok(Embedding::new(embeddings, hidden_size))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
let weight = vb.get((size2, size1), "weight")?;
|
||||||
let bias = vb.get(size2, &format!("{p}.bias"))?;
|
let bias = vb.get(size2, "bias")?;
|
||||||
Ok(Linear::new(weight, Some(bias)))
|
Ok(Linear::new(weight, Some(bias)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,17 +135,11 @@ impl Dropout {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||||
let (weight, bias) = match (
|
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||||
vb.get(size, &format!("{p}.weight")),
|
|
||||||
vb.get(size, &format!("{p}.bias")),
|
|
||||||
) {
|
|
||||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
if let (Ok(weight), Ok(bias)) = (
|
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||||
vb.get(size, &format!("{p}.gamma")),
|
|
||||||
vb.get(size, &format!("{p}.beta")),
|
|
||||||
) {
|
|
||||||
(weight, bias)
|
(weight, bias)
|
||||||
} else {
|
} else {
|
||||||
return Err(err.into());
|
return Err(err.into());
|
||||||
@ -167,33 +161,29 @@ struct BertEmbeddings {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertEmbeddings {
|
impl BertEmbeddings {
|
||||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let word_embeddings = embedding(
|
let word_embeddings = embedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
&format!("{p}.word_embeddings"),
|
vb.pp("word_embeddings"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let position_embeddings = embedding(
|
let position_embeddings = embedding(
|
||||||
config.max_position_embeddings,
|
config.max_position_embeddings,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
&format!("{p}.position_embeddings"),
|
vb.pp("position_embeddings"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let token_type_embeddings = embedding(
|
let token_type_embeddings = embedding(
|
||||||
config.type_vocab_size,
|
config.type_vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
&format!("{p}.token_type_embeddings"),
|
vb.pp("token_type_embeddings"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let layer_norm = layer_norm(
|
let layer_norm = layer_norm(
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.layer_norm_eps,
|
config.layer_norm_eps,
|
||||||
&format!("{p}.LayerNorm"),
|
vb.pp("LayerNorm"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
|
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
|
||||||
let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
|
let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?;
|
||||||
let token_type_ids = position_ids.zeros_like()?;
|
let token_type_ids = position_ids.zeros_like()?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
word_embeddings,
|
word_embeddings,
|
||||||
@ -233,14 +223,14 @@ struct BertSelfAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertSelfAttention {
|
impl BertSelfAttention {
|
||||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||||
let all_head_size = config.num_attention_heads * attention_head_size;
|
let all_head_size = config.num_attention_heads * attention_head_size;
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
let hidden_size = config.hidden_size;
|
let hidden_size = config.hidden_size;
|
||||||
let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
|
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
|
||||||
let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
|
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
|
||||||
let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
|
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@ -289,18 +279,12 @@ struct BertSelfOutput {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertSelfOutput {
|
impl BertSelfOutput {
|
||||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let dense = linear(
|
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
|
||||||
config.hidden_size,
|
|
||||||
config.hidden_size,
|
|
||||||
&format!("{p}.dense"),
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
let layer_norm = layer_norm(
|
let layer_norm = layer_norm(
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.layer_norm_eps,
|
config.layer_norm_eps,
|
||||||
&format!("{p}.LayerNorm"),
|
vb.pp("LayerNorm"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -324,9 +308,9 @@ struct BertAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertAttention {
|
impl BertAttention {
|
||||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
|
let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
|
||||||
let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
|
let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
self_attention,
|
self_attention,
|
||||||
self_output,
|
self_output,
|
||||||
@ -347,13 +331,8 @@ struct BertIntermediate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertIntermediate {
|
impl BertIntermediate {
|
||||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let dense = linear(
|
let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
|
||||||
config.hidden_size,
|
|
||||||
config.intermediate_size,
|
|
||||||
&format!("{p}.dense"),
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
dense,
|
dense,
|
||||||
intermediate_act: config.hidden_act,
|
intermediate_act: config.hidden_act,
|
||||||
@ -375,18 +354,12 @@ struct BertOutput {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertOutput {
|
impl BertOutput {
|
||||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let dense = linear(
|
let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
|
||||||
config.intermediate_size,
|
|
||||||
config.hidden_size,
|
|
||||||
&format!("{p}.dense"),
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
let layer_norm = layer_norm(
|
let layer_norm = layer_norm(
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.layer_norm_eps,
|
config.layer_norm_eps,
|
||||||
&format!("{p}.LayerNorm"),
|
vb.pp("LayerNorm"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -411,10 +384,10 @@ struct BertLayer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertLayer {
|
impl BertLayer {
|
||||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
|
let attention = BertAttention::load(vb.pp("attention"), config)?;
|
||||||
let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
|
let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
|
||||||
let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
|
let output = BertOutput::load(vb.pp("output"), config)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
attention,
|
attention,
|
||||||
intermediate,
|
intermediate,
|
||||||
@ -441,12 +414,9 @@ struct BertEncoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertEncoder {
|
impl BertEncoder {
|
||||||
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let layers = (0..config.num_hidden_layers)
|
let layers = (0..config.num_hidden_layers)
|
||||||
.map(|index| {
|
.map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
|
||||||
let p = format!("{p}.layer.{index}");
|
|
||||||
BertLayer::load(&p, vb, config)
|
|
||||||
})
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
Ok(BertEncoder { layers })
|
Ok(BertEncoder { layers })
|
||||||
}
|
}
|
||||||
@ -469,17 +439,17 @@ struct BertModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BertModel {
|
impl BertModel {
|
||||||
fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
let (embeddings, encoder) = match (
|
let (embeddings, encoder) = match (
|
||||||
BertEmbeddings::load("embeddings", vb, config),
|
BertEmbeddings::load(vb.pp("embeddings"), config),
|
||||||
BertEncoder::load("encoder", vb, config),
|
BertEncoder::load(vb.pp("encoder"), config),
|
||||||
) {
|
) {
|
||||||
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
if let Some(model_type) = &config.model_type {
|
if let Some(model_type) = &config.model_type {
|
||||||
if let (Ok(embeddings), Ok(encoder)) = (
|
if let (Ok(embeddings), Ok(encoder)) = (
|
||||||
BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
|
BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
||||||
BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
|
BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
|
||||||
) {
|
) {
|
||||||
(embeddings, encoder)
|
(embeddings, encoder)
|
||||||
} else {
|
} else {
|
||||||
@ -493,7 +463,7 @@ impl BertModel {
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
embeddings,
|
embeddings,
|
||||||
encoder,
|
encoder,
|
||||||
device: vb.device.clone(),
|
device: vb.device().clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -576,7 +546,7 @@ impl Args {
|
|||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||||
let model = BertModel::load(&vb, &config)?;
|
let model = BertModel::load(vb, &config)?;
|
||||||
Ok((model, tokenizer))
|
Ok((model, tokenizer))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -169,7 +169,7 @@ fn main() -> Result<()> {
|
|||||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
|
let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
|
||||||
let config = Config::falcon7b();
|
let config = Config::falcon7b();
|
||||||
config.validate()?;
|
config.validate()?;
|
||||||
let model = Falcon::load(&vb, config)?;
|
let model = Falcon::load(vb, config)?;
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
|
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
|
||||||
|
@ -4,27 +4,21 @@ use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
|
|||||||
|
|
||||||
const MAX_SEQ_LEN: usize = 5000;
|
const MAX_SEQ_LEN: usize = 5000;
|
||||||
|
|
||||||
fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
let weight = vb.get((size2, size1), "weight")?;
|
||||||
let bias = if bias {
|
let bias = if bias {
|
||||||
Some(vb.get(size2, &format!("{p}.bias"))?)
|
Some(vb.get(size2, "bias")?)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
Ok(Linear::new(weight, bias))
|
Ok(Linear::new(weight, bias))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||||
let (weight, bias) = match (
|
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||||
vb.get(size, &format!("{p}.weight")),
|
|
||||||
vb.get(size, &format!("{p}.bias")),
|
|
||||||
) {
|
|
||||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
if let (Ok(weight), Ok(bias)) = (
|
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||||
vb.get(size, &format!("{p}.gamma")),
|
|
||||||
vb.get(size, &format!("{p}.beta")),
|
|
||||||
) {
|
|
||||||
(weight, bias)
|
(weight, bias)
|
||||||
} else {
|
} else {
|
||||||
return Err(err.into());
|
return Err(err.into());
|
||||||
@ -50,8 +44,8 @@ impl Dropout {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
|
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||||
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||||
Ok(Embedding::new(embeddings, hidden_size))
|
Ok(Embedding::new(embeddings, hidden_size))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -164,14 +158,14 @@ struct FalconRotaryEmbedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FalconRotaryEmbedding {
|
impl FalconRotaryEmbedding {
|
||||||
fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(device: &Device, cfg: &Config) -> Result<Self> {
|
||||||
let head_dim = cfg.head_dim();
|
let head_dim = cfg.head_dim();
|
||||||
let inv_freq: Vec<_> = (0..head_dim)
|
let inv_freq: Vec<_> = (0..head_dim)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||||
.collect();
|
.collect();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inv_freq: Tensor::new(inv_freq.as_slice(), &vb.device)?,
|
inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
|
||||||
cache: None,
|
cache: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -237,9 +231,9 @@ struct FalconAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FalconAttention {
|
impl FalconAttention {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let maybe_rotary = if cfg.rotary() {
|
let maybe_rotary = if cfg.rotary() {
|
||||||
let rotary = FalconRotaryEmbedding::load(vb, cfg)?;
|
let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
|
||||||
Some(rotary)
|
Some(rotary)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -251,20 +245,8 @@ impl FalconAttention {
|
|||||||
} else {
|
} else {
|
||||||
3 * hidden_size
|
3 * hidden_size
|
||||||
};
|
};
|
||||||
let query_key_value = linear(
|
let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
|
||||||
hidden_size,
|
let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
|
||||||
qkv_out_dim,
|
|
||||||
cfg.bias,
|
|
||||||
&format!("{p}.query_key_value"),
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
let dense = linear(
|
|
||||||
hidden_size,
|
|
||||||
hidden_size,
|
|
||||||
cfg.bias,
|
|
||||||
&format!("{p}.dense"),
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
query_key_value,
|
query_key_value,
|
||||||
dense,
|
dense,
|
||||||
@ -367,11 +349,11 @@ struct FalconMlp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FalconMlp {
|
impl FalconMlp {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let h = cfg.hidden_size;
|
let h = cfg.hidden_size;
|
||||||
let b = cfg.bias;
|
let b = cfg.bias;
|
||||||
let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
|
let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
|
||||||
let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
|
let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
|
||||||
let dropout = Dropout::new(cfg.hidden_dropout);
|
let dropout = Dropout::new(cfg.hidden_dropout);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
dense_h_to_4h,
|
dense_h_to_4h,
|
||||||
@ -397,23 +379,21 @@ struct FalconDecoderLayer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FalconDecoderLayer {
|
impl FalconDecoderLayer {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
|
let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
|
||||||
let inp_layernorm = layer_norm(
|
let inp_layernorm = layer_norm(
|
||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
cfg.layer_norm_epsilon,
|
cfg.layer_norm_epsilon,
|
||||||
&format!("{p}.input_layernorm"),
|
vb.pp("input_layernorm"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let self_attention = FalconAttention::load(&format!("{p}.self_attention"), vb, cfg)?;
|
let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
|
||||||
let post_attention_layernorm = if cfg.parallel_attn {
|
let post_attention_layernorm = if cfg.parallel_attn {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
let ln = layer_norm(
|
let ln = layer_norm(
|
||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
cfg.layer_norm_epsilon,
|
cfg.layer_norm_epsilon,
|
||||||
&format!("{p}.post_attention_layernorm"),
|
vb.pp("post_attention_layernorm"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
Some(ln)
|
Some(ln)
|
||||||
};
|
};
|
||||||
@ -480,23 +460,21 @@ impl Falcon {
|
|||||||
&self.config
|
&self.config
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
|
||||||
let word_embeddings = embedding(
|
let word_embeddings = embedding(
|
||||||
cfg.vocab_size,
|
cfg.vocab_size,
|
||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
"transformer.word_embeddings",
|
vb.pp("transformer.word_embeddings"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let blocks = (0..cfg.num_hidden_layers)
|
let blocks = (0..cfg.num_hidden_layers)
|
||||||
.map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
|
.map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let ln_f = layer_norm(
|
let ln_f = layer_norm(
|
||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
cfg.layer_norm_epsilon,
|
cfg.layer_norm_epsilon,
|
||||||
"transformer.ln_f",
|
vb.pp("transformer.ln_f"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
|
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
word_embeddings,
|
word_embeddings,
|
||||||
blocks,
|
blocks,
|
||||||
|
@ -38,19 +38,19 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
|
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||||
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||||
Ok(Embedding::new(embeddings, hidden_size))
|
Ok(Embedding::new(embeddings, hidden_size))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
let weight = vb.get((size2, size1), "weight")?;
|
||||||
let bias = vb.get(size2, &format!("{p}.bias"))?;
|
let bias = vb.get(size2, "bias")?;
|
||||||
Ok(Linear::new(weight, Some(bias)))
|
Ok(Linear::new(weight, Some(bias)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
let weight = vb.get((size2, size1), "weight")?;
|
||||||
Ok(Linear::new(weight, None))
|
Ok(Linear::new(weight, None))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,14 +59,10 @@ fn conv1d(
|
|||||||
out_channels: usize,
|
out_channels: usize,
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
config: Conv1dConfig,
|
config: Conv1dConfig,
|
||||||
p: &str,
|
vb: VarBuilder,
|
||||||
vb: &VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
) -> Result<Conv1d> {
|
||||||
let weight = vb.get(
|
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
|
||||||
(out_channels, in_channels, kernel_size),
|
let bias = vb.get(out_channels, "bias")?;
|
||||||
&format!("{p}.weight"),
|
|
||||||
)?;
|
|
||||||
let bias = vb.get(out_channels, &format!("{p}.bias"))?;
|
|
||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,13 +71,9 @@ fn conv1d_no_bias(
|
|||||||
out_channels: usize,
|
out_channels: usize,
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
config: Conv1dConfig,
|
config: Conv1dConfig,
|
||||||
p: &str,
|
vb: VarBuilder,
|
||||||
vb: &VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
) -> Result<Conv1d> {
|
||||||
let weight = vb.get(
|
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
|
||||||
(out_channels, in_channels, kernel_size),
|
|
||||||
&format!("{p}.weight"),
|
|
||||||
)?;
|
|
||||||
Ok(Conv1d::new(weight, None, config))
|
Ok(Conv1d::new(weight, None, config))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,9 +92,9 @@ impl Dropout {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
||||||
let weight = vb.get(size, &format!("{p}.weight"))?;
|
let weight = vb.get(size, "weight")?;
|
||||||
let bias = vb.get(size, &format!("{p}.bias"))?;
|
let bias = vb.get(size, "bias")?;
|
||||||
Ok(LayerNorm::new(weight, bias, 1e-5))
|
Ok(LayerNorm::new(weight, bias, 1e-5))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,11 +108,11 @@ struct MultiHeadAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MultiHeadAttention {
|
impl MultiHeadAttention {
|
||||||
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?;
|
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
|
||||||
let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?;
|
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
|
||||||
let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
|
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
|
||||||
let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?;
|
let out = linear(n_state, n_state, vb.pp("out_proj"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@ -179,21 +171,20 @@ struct ResidualAttentionBlock {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ResidualAttentionBlock {
|
impl ResidualAttentionBlock {
|
||||||
fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
|
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
|
||||||
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
|
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
|
||||||
let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
|
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
|
||||||
let cross_attn = if ca {
|
let cross_attn = if ca {
|
||||||
let cross_attn =
|
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
|
||||||
MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
|
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
|
||||||
let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
|
|
||||||
Some((cross_attn, cross_attn_ln))
|
Some((cross_attn, cross_attn_ln))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let n_mlp = n_state * 4;
|
let n_mlp = n_state * 4;
|
||||||
let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
|
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
|
||||||
let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
|
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
|
||||||
let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?;
|
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
attn,
|
attn,
|
||||||
attn_ln,
|
attn_ln,
|
||||||
@ -245,7 +236,7 @@ pub struct AudioEncoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AudioEncoder {
|
impl AudioEncoder {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let n_state = cfg.d_model;
|
let n_state = cfg.d_model;
|
||||||
let n_head = cfg.encoder_attention_heads;
|
let n_head = cfg.encoder_attention_heads;
|
||||||
let n_ctx = cfg.max_source_positions;
|
let n_ctx = cfg.max_source_positions;
|
||||||
@ -257,22 +248,15 @@ impl AudioEncoder {
|
|||||||
padding: 1,
|
padding: 1,
|
||||||
stride: 2,
|
stride: 2,
|
||||||
};
|
};
|
||||||
let conv1 = conv1d(
|
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||||
cfg.num_mel_bins,
|
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||||
n_state,
|
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
|
||||||
3,
|
|
||||||
cfg1,
|
|
||||||
&format!("{p}.conv1"),
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
|
|
||||||
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
|
|
||||||
let blocks = (0..cfg.encoder_layers)
|
let blocks = (0..cfg.encoder_layers)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
|
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
|
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conv1,
|
conv1,
|
||||||
conv2,
|
conv2,
|
||||||
@ -306,23 +290,22 @@ pub struct TextDecoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TextDecoder {
|
impl TextDecoder {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let n_state = cfg.d_model;
|
let n_state = cfg.d_model;
|
||||||
let n_head = cfg.decoder_attention_heads;
|
let n_head = cfg.decoder_attention_heads;
|
||||||
let n_ctx = cfg.max_target_positions;
|
let n_ctx = cfg.max_target_positions;
|
||||||
let token_embedding = embedding(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
|
let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
|
||||||
let positional_embedding =
|
let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
|
||||||
vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
|
|
||||||
let blocks = (0..cfg.decoder_layers)
|
let blocks = (0..cfg.decoder_layers)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
|
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
|
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||||
let mask: Vec<_> = (0..n_ctx)
|
let mask: Vec<_> = (0..n_ctx)
|
||||||
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||||
.collect();
|
.collect();
|
||||||
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), &vb.device)?;
|
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
token_embedding,
|
token_embedding,
|
||||||
@ -361,8 +344,8 @@ pub struct Whisper {
|
|||||||
|
|
||||||
impl Whisper {
|
impl Whisper {
|
||||||
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
|
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
|
||||||
let encoder = AudioEncoder::load("model.encoder", vb, &config)?;
|
let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
|
||||||
let decoder = TextDecoder::load("model.decoder", vb, &config)?;
|
let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
encoder,
|
encoder,
|
||||||
decoder,
|
decoder,
|
||||||
|
@ -1,53 +1,118 @@
|
|||||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub struct VarBuilder<'a> {
|
struct SafeTensorWithRouting<'a> {
|
||||||
safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
|
routing: HashMap<String, usize>,
|
||||||
|
safetensors: Vec<SafeTensors<'a>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TensorData<'a> {
|
||||||
|
// TODO: Make this part generic, probably via some Box<dyn> to avoid too much generics.
|
||||||
|
safetensors: Option<SafeTensorWithRouting<'a>>,
|
||||||
pub dtype: DType,
|
pub dtype: DType,
|
||||||
pub device: Device,
|
pub device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> VarBuilder<'a> {
|
impl<'a> TensorData<'a> {
|
||||||
pub fn from_safetensors(
|
fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self {
|
||||||
safetensors: Vec<SafeTensors<'a>>,
|
|
||||||
dtype: DType,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let mut routing = HashMap::new();
|
let mut routing = HashMap::new();
|
||||||
for (index, sf) in safetensors.iter().enumerate() {
|
for (index, sf) in safetensors.iter().enumerate() {
|
||||||
for k in sf.names() {
|
for k in sf.names() {
|
||||||
routing.insert(k.to_string(), index);
|
routing.insert(k.to_string(), index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let safetensors = SafeTensorWithRouting {
|
||||||
|
routing,
|
||||||
|
safetensors,
|
||||||
|
};
|
||||||
Self {
|
Self {
|
||||||
safetensors: Some((routing, safetensors)),
|
safetensors: Some(safetensors),
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
dtype,
|
dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn zeros(dtype: DType, device: Device) -> Self {
|
fn zeros(dtype: DType, device: &Device) -> Self {
|
||||||
Self {
|
Self {
|
||||||
safetensors: None,
|
safetensors: None,
|
||||||
device,
|
device: device.clone(),
|
||||||
dtype,
|
dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct VarBuilder<'a> {
|
||||||
|
data: Arc<TensorData<'a>>,
|
||||||
|
path: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> VarBuilder<'a> {
|
||||||
|
/// Create a `VarBuilder` accessing data frome the safetensors storage. The initial path is
|
||||||
|
/// set to the root path and sub-paths can be created via the `push_prefix` method.
|
||||||
|
pub fn from_safetensors(st: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self {
|
||||||
|
let data = TensorData::from_safetensors(st, dtype, device);
|
||||||
|
Self {
|
||||||
|
data: Arc::new(data),
|
||||||
|
path: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn zeros(dtype: DType, device: &Device) -> Self {
|
||||||
|
let data = TensorData::zeros(dtype, device);
|
||||||
|
Self {
|
||||||
|
data: Arc::new(data),
|
||||||
|
path: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push_prefix(&self, s: &str) -> Self {
|
||||||
|
let mut path = self.path.clone();
|
||||||
|
path.push(s.to_string());
|
||||||
|
Self {
|
||||||
|
data: self.data.clone(),
|
||||||
|
path,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Short alias for `push_prefix`.
|
||||||
|
pub fn pp(&self, s: &str) -> Self {
|
||||||
|
self.push_prefix(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &Device {
|
||||||
|
&self.data.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> DType {
|
||||||
|
self.data.dtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> VarBuilder<'a> {
|
||||||
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
|
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
|
||||||
|
let data = self.data.as_ref();
|
||||||
let s: Shape = s.into();
|
let s: Shape = s.into();
|
||||||
match &self.safetensors {
|
match &self.data.safetensors {
|
||||||
None => Tensor::zeros(s, self.dtype, &self.device),
|
None => Tensor::zeros(s, data.dtype, &data.device),
|
||||||
Some((routing, safetensors)) => {
|
Some(SafeTensorWithRouting {
|
||||||
|
routing,
|
||||||
|
safetensors,
|
||||||
|
}) => {
|
||||||
|
let path = if self.path.is_empty() {
|
||||||
|
tensor_name.to_string()
|
||||||
|
} else {
|
||||||
|
[&self.path.join("."), tensor_name].join(".")
|
||||||
|
};
|
||||||
// Unwrap or 0 just to let the proper error flow.
|
// Unwrap or 0 just to let the proper error flow.
|
||||||
let index = routing.get(tensor_name).unwrap_or(&0);
|
let index = routing.get(&path).unwrap_or(&0);
|
||||||
let tensor = safetensors[*index]
|
let tensor = safetensors[*index]
|
||||||
.tensor(tensor_name, &self.device)?
|
.tensor(&path, &data.device)?
|
||||||
.to_dtype(self.dtype)?;
|
.to_dtype(data.dtype)?;
|
||||||
if *tensor.shape() != s {
|
if *tensor.shape() != s {
|
||||||
let msg = format!("shape mismatch for {tensor_name}");
|
|
||||||
Err(candle::Error::UnexpectedShape {
|
Err(candle::Error::UnexpectedShape {
|
||||||
msg,
|
msg: format!("shape mismatch for {path}"),
|
||||||
expected: s,
|
expected: s,
|
||||||
got: tensor.shape().clone(),
|
got: tensor.shape().clone(),
|
||||||
})?
|
})?
|
||||||
|
Reference in New Issue
Block a user