Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Simplify the parameters used by sum and sum_keepdim. (#165)
@@ -604,7 +604,7 @@ fn main() -> Result<()> {
     println!("generated embeddings {:?}", embeddings.shape());
     // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
     let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
-    let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?;
+    let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
     println!("pooled embeddings {:?}", embeddings.shape());
     let mut similarities = vec![];
     for i in 0..n_sentences {
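With the simplified signature, a reduction over a single dimension is written as `sum(1)`, and the removal of the explicit `squeeze(1)` suggests that `sum` now drops the reduced dimension itself. A minimal sketch of this pooling step under that assumption, with an assumed `candle` import path and made-up shapes:

use candle::{DType, Device, Result, Tensor};

fn mean_pool_sketch() -> Result<()> {
    // Illustrative embeddings: 2 sentences, 4 tokens, 3 hidden dims (all ones).
    let embeddings = Tensor::ones((2, 4, 3), DType::F32, &Device::Cpu)?;
    let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
    // sum(1) reduces over the token axis and removes it: (2, 4, 3) -> (2, 3).
    let pooled = (embeddings.sum(1)? / (n_tokens as f64))?;
    assert_eq!(pooled.dims(), &[2, 3]);
    Ok(())
}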
@@ -95,7 +95,7 @@ impl RmsNorm {
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
         let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
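Here only the call style changes: a single dimension index is passed directly as `sum_keepdim(2)` rather than as a slice, and the reduced axis is kept. A minimal sketch of the keep-dim behaviour, with an assumed `candle` import path and illustrative shapes:

use candle::{DType, Device, Result, Tensor};

fn rms_norm_denominator_sketch() -> Result<()> {
    // Illustrative activations: batch 2, sequence 5, hidden size 8 (all ones).
    let x = Tensor::ones((2, 5, 8), DType::F32, &Device::Cpu)?;
    let (_b_sz, _seq_len, hidden_size) = x.shape().r3()?;
    // sum_keepdim(2) reduces the hidden axis but keeps it: (2, 5, 8) -> (2, 5, 1).
    let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
    assert_eq!(norm_x.dims(), &[2, 5, 1]);
    Ok(())
}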
@@ -70,7 +70,7 @@ pub fn conv1d_weight_norm(
 ) -> Result<Conv1d> {
     let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
     let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
-    let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?;
+    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
     let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
     let bias = vb.get(out_c, "bias")?;
     Ok(Conv1d::new(weight, Some(bias), config))
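Reductions over several dimensions now take a tuple, `sum_keepdim((1, 2))`, instead of a slice. A minimal sketch of the weight-norm denominator computed this way, with an assumed `candle` import path and illustrative shapes:

use candle::{DType, Device, Result, Tensor};

fn weight_norm_denominator_sketch() -> Result<()> {
    // Illustrative weight_v: out_c 4, in_c 3, kernel_size 5 (all ones).
    let weight_v = Tensor::ones((4, 3, 5), DType::F32, &Device::Cpu)?;
    // A tuple of dims reduces over both axes while keeping them: (4, 3, 5) -> (4, 1, 1).
    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
    assert_eq!(norm_v.dims(), &[4, 1, 1]);
    Ok(())
}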
@@ -98,7 +98,7 @@ impl T5LayerNorm {
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
         let xs2_f32 = (&xs_f32 * &xs_f32)?;
-        let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
+        let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
         let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
         let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
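The last dimension can now be named with `D::Minus1` rather than computing `xs.rank() - 1` by hand. A minimal sketch of that change, with an assumed `candle` import path and illustrative shapes:

use candle::{D, DType, Device, Result, Tensor};

fn last_dim_keepdim_sketch() -> Result<()> {
    // Illustrative hidden states: batch 2, sequence 3, hidden size 4 (all ones).
    let xs = Tensor::ones((2, 3, 4), DType::F32, &Device::Cpu)?;
    let xs2 = (&xs * &xs)?;
    // D::Minus1 addresses the last axis directly: (2, 3, 4) -> (2, 3, 1).
    let sum_xs2 = xs2.sum_keepdim(D::Minus1)?;
    assert_eq!(sum_xs2.dims(), &[2, 3, 1]);
    Ok(())
}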