Rename the .r functions to .dims so as to be a bit more explicit. (#220)

Laurent Mazare
2023-07-22 11:39:27 +02:00
committed by GitHub
parent 52c5d8c087
commit 43c7223292
18 changed files with 56 additions and 50 deletions
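
For reference, a minimal standalone sketch of what the rename means for callers. Only dims3/dims2 replacing shape().r3()/shape().r2(), and the Tensor::zeros signature, are taken from the diff below; the candle_core crate name, Device::Cpu, and the variable names are illustrative assumptions:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // A rank-3 tensor of zeros, mirroring the Tensor::zeros call in the diff.
    let xs = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // Before #220: dimensions were destructured via the `.r3()` helper on Shape:
    // let (b_sz, seq_len, d_model) = xs.shape().r3()?;
    // After #220: call `.dims3()` directly on the tensor; it returns an error
    // if the tensor rank is not exactly 3.
    let (b_sz, seq_len, d_model) = xs.dims3()?;
    assert_eq!((b_sz, seq_len, d_model), (2, 3, 4));
    Ok(())
}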


@@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
     }
     fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
+        let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
         if seq_len > self.weights.dim(0)? {
             self.weights = get_embedding(seq_len, self.embedding_dim)?
         }
@@ -170,7 +170,7 @@ impl MusicgenAttention {
         kv_states: Option<&Tensor>,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
-        let (b_sz, tgt_len, _) = xs.shape().r3()?;
+        let (b_sz, tgt_len, _) = xs.dims3()?;
         let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
         let kv_states = kv_states.unwrap_or(xs);
@@ -308,7 +308,7 @@ impl MusicgenDecoder {
     fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
         let dev = input_ids.device();
-        let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
+        let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
         let b_sz = b_sz_times_codebooks / self.num_codebooks;
         let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
         let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
@@ -352,7 +352,7 @@ impl MusicgenForCausalLM {
     }
     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len) = input_ids.shape().r2()?;
+        let (b_sz, seq_len) = input_ids.dims2()?;
         let hidden_states = self.decoder.forward(input_ids)?;
         let lm_logits = self
             .lm_heads


@@ -338,7 +338,7 @@ impl T5Stack {
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
+        let (_b_sz, _seq_len) = input_embeds.dims2()?;
         let mut hidden_states = self.dropout.forward(&input_embeds)?;
         for block in self.block.iter() {