Simplify the parameters used by sum and sum_keepdim. (#165)

This commit is contained in:
Laurent Mazare
2023-07-14 08:22:08 +01:00
committed by GitHub
parent 2bfa791336
commit a2f72edc0d
13 changed files with 179 additions and 98 deletions

View File

@ -70,7 +70,7 @@ pub fn conv1d_weight_norm(
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))

View File

@ -98,7 +98,7 @@ impl T5LayerNorm {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
let xs2_f32 = (&xs_f32 * &xs_f32)?;
let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;