mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Simplify the parameters used by sum and sum_keepdim. (#165)
This commit is contained in:
@ -70,7 +70,7 @@ pub fn conv1d_weight_norm(
|
||||
) -> Result<Conv1d> {
|
||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
|
Reference in New Issue
Block a user