Rename the .r functions to .dims so as to be a bit more explicit. (#220)

This commit is contained in:
Laurent Mazare
2023-07-22 11:39:27 +02:00
committed by GitHub
parent 52c5d8c087
commit 43c7223292
18 changed files with 56 additions and 50 deletions

View File

@ -41,7 +41,7 @@ impl Conv1d {
match &self.bias {
None => Ok(x),
Some(bias) => {
let b = bias.shape().r1()?;
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1))?;
Ok(x.broadcast_add(&bias)?)
}

View File

@ -49,7 +49,7 @@ impl LayerNorm {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;