mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
This commit is contained in:
@ -182,7 +182,7 @@ impl FalconRotaryEmbedding {
|
||||
key: &Tensor,
|
||||
past_kv_len: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_batch, seq_len, _head_dim) = query.shape().r3()?;
|
||||
let (_batch, seq_len, _head_dim) = query.dims3()?;
|
||||
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
|
||||
let cos = cos.narrow(0, past_kv_len, seq_len)?;
|
||||
let sin = sin.narrow(0, past_kv_len, seq_len)?;
|
||||
@ -245,7 +245,7 @@ impl FalconAttention {
|
||||
}
|
||||
|
||||
fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let (b_sz, seq_len, _) = fused_qkv.shape().r3()?;
|
||||
let (b_sz, seq_len, _) = fused_qkv.dims3()?;
|
||||
if !self.multi_query {
|
||||
let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
|
||||
let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
|
||||
@ -267,7 +267,7 @@ impl FalconAttention {
|
||||
let fused_qkv = self.query_key_value.forward(x)?;
|
||||
let head_dim = self.head_dim;
|
||||
let (query, key, value) = self.split_heads(&fused_qkv)?;
|
||||
let (b_sz, seq_len, _, _) = query.shape().r4()?;
|
||||
let (b_sz, seq_len, _, _) = query.dims4()?;
|
||||
let query = query
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
|
||||
@ -465,7 +465,7 @@ impl Falcon {
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, seq_len) = input_ids.shape().r2()?;
|
||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
|
||||
let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
|
||||
Some((k, _)) => k.dim(1)?,
|
||||
|
Reference in New Issue
Block a user