mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
MusicGen var-store path cleanup. (#132)
This commit is contained in:
@ -8,10 +8,10 @@ const MAX_SEQ_LEN: usize = 5000;
|
||||
pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>;
|
||||
pub type Linear = candle_nn::Linear;
|
||||
|
||||
pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||
pub fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, &format!("{p}.bias"))?)
|
||||
Some(vb.get(size2, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -20,17 +20,11 @@ pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder)
|
||||
|
||||
pub type LayerNorm = candle_nn::LayerNorm;
|
||||
|
||||
pub fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (
|
||||
vb.get(size, &format!("{p}.weight")),
|
||||
vb.get(size, &format!("{p}.bias")),
|
||||
) {
|
||||
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let (Ok(weight), Ok(bias)) = (
|
||||
vb.get(size, &format!("{p}.gamma")),
|
||||
vb.get(size, &format!("{p}.beta")),
|
||||
) {
|
||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||
(weight, bias)
|
||||
} else {
|
||||
return Err(err.into());
|
||||
@ -58,13 +52,8 @@ impl Dropout {
|
||||
|
||||
pub type Embedding = candle_nn::Embedding;
|
||||
|
||||
pub fn embedding(
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
p: &str,
|
||||
vb: &VarBuilder,
|
||||
) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
||||
pub fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
@ -79,14 +68,13 @@ pub fn conv1d_weight_norm(
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
p: &str,
|
||||
vb: &VarBuilder,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?;
|
||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||
let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
let bias = vb.get(out_c, &format!("{p}.bias"))?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
@ -95,11 +83,10 @@ pub fn conv1d(
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
p: &str,
|
||||
vb: &VarBuilder,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?;
|
||||
let bias = vb.get(out_c, &format!("{p}.bias"))?;
|
||||
let weight = vb.get((out_c, in_c, kernel_size), "weight")?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user