mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
This commit is contained in:
@ -161,7 +161,7 @@ fn main() -> Result<()> {
|
||||
let embeddings = model.forward(&token_ids, &token_type_ids)?;
|
||||
println!("generated embeddings {:?}", embeddings.shape());
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
|
||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||
println!("pooled embeddings {:?}", embeddings.shape());
|
||||
let mut similarities = vec![];
|
||||
|
@ -87,7 +87,7 @@ impl LayerNorm {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
@ -262,7 +262,7 @@ impl BertEmbeddings {
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_bsize, seq_len) = input_ids.shape().r2()?;
|
||||
let (_bsize, seq_len) = input_ids.dims2()?;
|
||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
||||
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
|
||||
|
@ -182,7 +182,7 @@ impl FalconRotaryEmbedding {
|
||||
key: &Tensor,
|
||||
past_kv_len: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_batch, seq_len, _head_dim) = query.shape().r3()?;
|
||||
let (_batch, seq_len, _head_dim) = query.dims3()?;
|
||||
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
|
||||
let cos = cos.narrow(0, past_kv_len, seq_len)?;
|
||||
let sin = sin.narrow(0, past_kv_len, seq_len)?;
|
||||
@ -245,7 +245,7 @@ impl FalconAttention {
|
||||
}
|
||||
|
||||
fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let (b_sz, seq_len, _) = fused_qkv.shape().r3()?;
|
||||
let (b_sz, seq_len, _) = fused_qkv.dims3()?;
|
||||
if !self.multi_query {
|
||||
let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
|
||||
let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
|
||||
@ -267,7 +267,7 @@ impl FalconAttention {
|
||||
let fused_qkv = self.query_key_value.forward(x)?;
|
||||
let head_dim = self.head_dim;
|
||||
let (query, key, value) = self.split_heads(&fused_qkv)?;
|
||||
let (b_sz, seq_len, _, _) = query.shape().r4()?;
|
||||
let (b_sz, seq_len, _, _) = query.dims4()?;
|
||||
let query = query
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
|
||||
@ -465,7 +465,7 @@ impl Falcon {
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, seq_len) = input_ids.shape().r2()?;
|
||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
|
||||
let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
|
||||
Some((k, _)) => k.dim(1)?,
|
||||
|
@ -116,11 +116,11 @@ impl RmsNorm {
|
||||
let in_dtype = x.dtype();
|
||||
// This is a no-op if x's dtype is already f32.
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
|
||||
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
|
||||
let size = self.scale.shape().r1()?;
|
||||
let size = self.scale.dims1()?;
|
||||
let scale = self
|
||||
.scale
|
||||
.to_dtype(DType::F32)?
|
||||
@ -144,7 +144,7 @@ struct CausalSelfAttention {
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
|
||||
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
|
||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
@ -158,7 +158,7 @@ impl CausalSelfAttention {
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let (b_sz, seq_len, n_embd) = x.shape().r3()?;
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let q = self.q_proj.forward(x)?;
|
||||
let k = self.k_proj.forward(x)?;
|
||||
let v = self.v_proj.forward(x)?;
|
||||
@ -219,7 +219,7 @@ impl CausalSelfAttention {
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
@ -345,7 +345,7 @@ impl Llama {
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.shape().r2()?;
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx)?;
|
||||
|
@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
|
||||
}
|
||||
|
||||
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
|
||||
let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
|
||||
if seq_len > self.weights.dim(0)? {
|
||||
self.weights = get_embedding(seq_len, self.embedding_dim)?
|
||||
}
|
||||
@ -170,7 +170,7 @@ impl MusicgenAttention {
|
||||
kv_states: Option<&Tensor>,
|
||||
attention_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, tgt_len, _) = xs.shape().r3()?;
|
||||
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||
let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
|
||||
|
||||
let kv_states = kv_states.unwrap_or(xs);
|
||||
@ -308,7 +308,7 @@ impl MusicgenDecoder {
|
||||
|
||||
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let dev = input_ids.device();
|
||||
let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
|
||||
let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
|
||||
let b_sz = b_sz_times_codebooks / self.num_codebooks;
|
||||
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
|
||||
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
|
||||
@ -352,7 +352,7 @@ impl MusicgenForCausalLM {
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, seq_len) = input_ids.shape().r2()?;
|
||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||
let hidden_states = self.decoder.forward(input_ids)?;
|
||||
let lm_logits = self
|
||||
.lm_heads
|
||||
|
@ -338,7 +338,7 @@ impl T5Stack {
|
||||
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
||||
let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
|
||||
let (_b_sz, _seq_len) = input_embeds.dims2()?;
|
||||
|
||||
let mut hidden_states = self.dropout.forward(&input_embeds)?;
|
||||
for block in self.block.iter() {
|
||||
|
@ -52,7 +52,7 @@ pub fn main() -> Result<()> {
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.shape().r1()? as f32;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
|
@ -127,7 +127,7 @@ impl Decoder {
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
|
||||
let (seq_len, _) = logits.shape().r2()?;
|
||||
let (seq_len, _) = logits.dims2()?;
|
||||
let logits = logits
|
||||
.get(seq_len - 1)?
|
||||
.broadcast_add(&self.suppress_tokens)?;
|
||||
@ -195,7 +195,7 @@ impl Decoder {
|
||||
}
|
||||
|
||||
fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
|
||||
let (_, _, content_frames) = mel.shape().r3()?;
|
||||
let (_, _, content_frames) = mel.dims3()?;
|
||||
let mut seek = 0;
|
||||
let mut segments = vec![];
|
||||
while seek < content_frames {
|
||||
|
@ -132,7 +132,7 @@ impl MultiHeadAttention {
|
||||
}
|
||||
|
||||
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (n_batch, n_ctx, n_state) = x.shape().r3()?;
|
||||
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
||||
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
||||
}
|
||||
@ -144,7 +144,7 @@ impl MultiHeadAttention {
|
||||
v: &Tensor,
|
||||
mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let (_, n_ctx, n_state) = q.shape().r3()?;
|
||||
let (_, n_ctx, n_state) = q.dims3()?;
|
||||
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||
let q = (self.reshape_head(q)? * scale)?;
|
||||
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||
@ -270,7 +270,7 @@ impl AudioEncoder {
|
||||
let x = self.conv1.forward(x)?.gelu()?;
|
||||
let x = self.conv2.forward(&x)?.gelu()?;
|
||||
let x = x.transpose(1, 2)?;
|
||||
let (_bsize, seq_len, _hidden) = x.shape().r3()?;
|
||||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
|
Reference in New Issue
Block a user