Support dim indexes in cat.

laurent
2023-07-05 20:39:08 +01:00
parent fc2ffcc72b
commit e2bfbcb79c
2 changed files with 14 additions and 13 deletions
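In short: Tensor::cat and Tensor::stack now accept any dimension index implementing the Dim trait, so callers can pass either a concrete usize or a relative index such as D::Minus1. A minimal usage sketch; the shapes, DType::F32, the Device::Cpu setup, Tensor::zeros, and a Dim impl for usize are assumptions for illustration, none of them shown in this diff:

    use candle::{DType, Device, Tensor, D};

    let dev = Device::Cpu;
    let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
    let b = Tensor::zeros((2, 3), DType::F32, &dev)?;
    // A concrete index keeps working (assumes an `impl Dim for usize`)...
    let ab = Tensor::cat(&[&a, &b], 1)?;           // shape (2, 6)
    // ...and a relative index now resolves against the first arg's shape.
    let ab = Tensor::cat(&[&a, &b], D::Minus1)?;   // shape (2, 6)
    let stacked = Tensor::stack(&[&a, &b], 0)?;    // shape (2, 2, 3)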

File 1 of 2:

@@ -970,10 +970,11 @@ impl Tensor {
         self.reshape(dims)
     }

-    pub fn stack<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
+    pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
         }
+        let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
         let args = args
             .iter()
             .map(|t| t.as_ref().unsqueeze(dim))
@@ -981,7 +982,7 @@ impl Tensor {
         Self::cat(&args, dim)
     }

-    pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
+    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
@@ -989,6 +990,7 @@ impl Tensor {
         if args.len() == 1 {
             return Ok(arg0.clone());
         }
+        let dim = dim.to_index(arg0.shape(), "cat")?;
         for arg in args {
             arg.as_ref().check_dim(dim, "cat")?;
         }

File 2 of 2:

@@ -15,7 +15,7 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};

-use candle::{DType, Device, Tensor};
+use candle::{DType, Device, Tensor, D};
 use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -283,18 +283,18 @@ impl CausalSelfAttention {
         dims.push(v / 2);
         dims.push(2);
         let x = x.reshape(dims)?;
-        let re_x = x.narrow(candle::D::Minus1, 0, 1)?;
-        let im_x = x.narrow(candle::D::Minus1, 1, 1)?;
+        let re_x = x.narrow(D::Minus1, 0, 1)?;
+        let im_x = x.narrow(D::Minus1, 1, 1)?;
         let re_f = freqs_cis
-            .narrow(candle::D::Minus1, 0, 1)?
+            .narrow(D::Minus1, 0, 1)?
             .broadcast_as(re_x.shape())?;
         let im_f = freqs_cis
-            .narrow(candle::D::Minus1, 1, 1)?
+            .narrow(D::Minus1, 1, 1)?
             .broadcast_as(im_x.shape())?;
         let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
         let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
-        let rope = Tensor::cat(&[&re, &im], re.rank() - 1)?;
-        let rope = rope.flatten_from(candle::D::Minus2)?;
+        let rope = Tensor::cat(&[&re, &im], D::Minus1)?;
+        let rope = rope.flatten_from(D::Minus2)?;
         Ok(rope)
     }
@@ -338,7 +338,7 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(candle::D::Minus1)?;
+        let att = att.softmax(D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
@@ -424,8 +424,7 @@ fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
     let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
     let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
     let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
-    let last_dim = idx_theta_cos.rank() - 1;
-    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?)
+    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], D::Minus1)?)
 }

 #[derive(Parser, Debug)]
@@ -536,7 +535,7 @@ async fn main() -> Result<()> {
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
-            let prs = (&logits / temperature)?.softmax(candle::D::Minus1)?;
+            let prs = (&logits / temperature)?.softmax(D::Minus1)?;
             let logits_v: Vec<f32> = prs.to_vec1()?;
             let distr = rand::distributions::WeightedIndex::new(&logits_v)?;