mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
This commit is contained in:
@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
|
||||
}
|
||||
|
||||
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
|
||||
let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
|
||||
if seq_len > self.weights.dim(0)? {
|
||||
self.weights = get_embedding(seq_len, self.embedding_dim)?
|
||||
}
|
||||
@ -170,7 +170,7 @@ impl MusicgenAttention {
|
||||
kv_states: Option<&Tensor>,
|
||||
attention_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, tgt_len, _) = xs.shape().r3()?;
|
||||
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||
let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
|
||||
|
||||
let kv_states = kv_states.unwrap_or(xs);
|
||||
@ -308,7 +308,7 @@ impl MusicgenDecoder {
|
||||
|
||||
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let dev = input_ids.device();
|
||||
let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
|
||||
let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
|
||||
let b_sz = b_sz_times_codebooks / self.num_codebooks;
|
||||
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
|
||||
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
|
||||
@ -352,7 +352,7 @@ impl MusicgenForCausalLM {
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, seq_len) = input_ids.shape().r2()?;
|
||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||
let hidden_states = self.decoder.forward(input_ids)?;
|
||||
let lm_logits = self
|
||||
.lm_heads
|
||||
|
Reference in New Issue
Block a user