mirror of https://github.com/huggingface/candle.git
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
@@ -87,7 +87,7 @@ impl LayerNorm {
             DType::F16 | DType::BF16 => DType::F32,
             d => d,
         };
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
         let x = x.to_dtype(internal_dtype)?;
         let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
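Aside from the rename, the context lines above show how this LayerNorm centers its input: the mean is taken over the hidden dimension and broadcast-subtracted. A minimal standalone sketch of that step, assuming the candle_core crate name (the crate path and the example shape are illustrative, not part of this commit):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // A (batch, seq_len, hidden) tensor; shape chosen for illustration.
    let x = Tensor::ones((1, 2, 4), DType::F32, &Device::Cpu)?;
    let (_bsize, _seq_len, hidden_size) = x.dims3()?;
    // Mean over the hidden dimension (dim 2), keeping the dim so the
    // result broadcasts back against x.
    let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
    // Centered activations: every value of an all-ones tensor becomes 0.
    let x = x.broadcast_sub(&mean_x)?;
    assert_eq!(x.dims3()?, (1, 2, 4));
    Ok(())
}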
@@ -262,7 +262,7 @@ impl BertEmbeddings {
 
     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let (_bsize, seq_len) = input_ids.shape().r2()?;
+        let (_bsize, seq_len) = input_ids.dims2()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
         let mut embeddings = (&input_embeddings + token_type_embeddings)?;