Merge pull request #82 from LaurentMazare/dim-index

Add a simpler way to specify the dim index for some ops.
Laurent Mazare
2023-07-05 20:24:43 +01:00
committed by GitHub
7 changed files with 93 additions and 34 deletions
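
For reference, the call-site change this enables looks like the following sketch (the tensor constructor and shapes are illustrative, not part of the diff):

```rust
use candle::{DType, Device, Tensor, D};

fn demo() -> candle::Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // Before: compute the index of the last dimension by hand.
    let _probs = t.softmax(t.rank() - 1)?;
    // After: name the dimension relative to the end. Plain usize still works,
    // since usize also implements the new Dim trait.
    let _probs = t.softmax(D::Minus1)?;
    Ok(())
}
```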

View File

@@ -23,7 +23,7 @@ pub use device::{Device, DeviceLocation};
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
 pub use layout::Layout;
-pub use shape::Shape;
+pub use shape::{Shape, D};
 pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};

View File

@@ -183,6 +183,45 @@ impl Shape {
     }
 }
 
+pub trait Dim {
+    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
+}
+
+impl Dim for usize {
+    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+        let dim = *self;
+        if dim >= shape.dims().len() {
+            Err(Error::DimOutOfRange {
+                shape: shape.clone(),
+                dim,
+                op,
+            })?
+        } else {
+            Ok(dim)
+        }
+    }
+}
+
+pub enum D {
+    Minus1,
+    Minus2,
+}
+
+impl Dim for D {
+    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+        let rank = shape.rank();
+        match self {
+            Self::Minus1 if rank >= 1 => Ok(rank - 1),
+            Self::Minus2 if rank >= 2 => Ok(rank - 2),
+            _ => Err(Error::DimOutOfRange {
+                shape: shape.clone(),
+                dim: 42, // TODO: Have an adequate error
+                op,
+            }),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
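
A quick sketch of how `to_index` resolves both index styles; the `use candle::shape::Dim` path is an assumption (the trait may only be reachable through the generic tensor ops):

```rust
use candle::{Result, Shape, D};
use candle::shape::Dim; // assumed import path for the trait

fn demo() -> Result<()> {
    let shape = Shape::from((2, 3, 4)); // rank 3
    assert_eq!(0usize.to_index(&shape, "demo")?, 0);
    assert_eq!(D::Minus1.to_index(&shape, "demo")?, 2); // rank - 1
    assert_eq!(D::Minus2.to_index(&shape, "demo")?, 1); // rank - 2
    // Out-of-range indices surface as Error::DimOutOfRange, not a panic.
    assert!(7usize.to_index(&shape, "demo").is_err());
    Ok(())
}
```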

View File

@@ -1,3 +1,4 @@
+use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::Arc;
@@ -362,9 +363,9 @@ impl Tensor {
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
     /// ranges from `start` to `start + len`.
-    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
+    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
         let dims = self.dims();
-        self.check_dim(dim, "narrow")?;
+        let dim = dim.to_index(self.shape(), "narrow")?;
         if start + len > dims[dim] {
             Err(Error::NarrowInvalidArgs {
                 shape: self.shape().clone(),
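
Per the doc comment above, `narrow` now accepts either form for `dim`; a minimal sketch, assuming the usual zeros constructor:

```rust
use candle::{DType, Device, Tensor, D};

fn demo() -> candle::Result<()> {
    let t = Tensor::zeros((4, 6), DType::F32, &Device::Cpu)?;
    let a = t.narrow(1, 2, 3)?;         // columns 2..5, shape (4, 3)
    let b = t.narrow(D::Minus1, 2, 3)?; // same slice, last dim named end-relative
    assert_eq!(a.dims(), b.dims());
    Ok(())
}
```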
@@ -392,8 +393,8 @@ impl Tensor {
         }
     }
 
-    pub fn softmax(&self, dim: usize) -> Result<Self> {
-        self.check_dim(dim, "softmax")?;
+    pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "softmax")?;
         // TODO: unify the two branches.
         if self.device().is_cuda() {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
@@ -692,14 +693,22 @@ impl Tensor {
         self.sum(&dims)
     }
 
-    pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
+    fn flatten_<D1: Dim, D2: Dim>(
+        &self,
+        start_dim: Option<D1>,
+        end_dim: Option<D2>,
+    ) -> Result<Tensor> {
         if self.rank() == 0 {
             self.reshape(1)
         } else {
-            let start_dim = start_dim.unwrap_or(0);
-            let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
-            self.check_dim(start_dim, "flatten")?;
-            self.check_dim(end_dim, "flatten")?;
+            let start_dim = match start_dim {
+                None => 0,
+                Some(dim) => dim.to_index(self.shape(), "flatten")?,
+            };
+            let end_dim = match end_dim {
+                None => self.rank() - 1,
+                Some(dim) => dim.to_index(self.shape(), "flatten")?,
+            };
             if start_dim < end_dim {
                 let dims = self.dims();
                 let mut dst_dims = dims[..start_dim].to_vec();
@@ -714,8 +723,20 @@ impl Tensor {
         }
     }
 
+    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
+        self.flatten_(Some(start_dim), Some(end_dim))
+    }
+
+    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
+        self.flatten_(None::<usize>, Some(end_dim))
+    }
+
+    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
+        self.flatten_(Some(start_dim), None::<usize>)
+    }
+
     pub fn flatten_all(&self) -> Result<Tensor> {
-        self.flatten(None, None)
+        self.flatten_(None::<usize>, None::<usize>)
     }
 
     pub fn get(&self, i: usize) -> Result<Tensor> {
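
The public `flatten*` entry points all route through the private `flatten_` helper. A sketch of what each produces for a rank-3 tensor, with shapes worked out from the code above (constructor assumed):

```rust
use candle::{DType, Device, Tensor, D};

fn demo() -> candle::Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    assert_eq!(t.flatten(0, 1)?.dims(), &[6, 4]);         // merge dims 0..=1
    assert_eq!(t.flatten_from(1)?.dims(), &[2, 12]);      // merge dims 1..=end
    assert_eq!(t.flatten_to(D::Minus2)?.dims(), &[6, 4]); // merge dims 0..=rank-2
    assert_eq!(t.flatten_all()?.dims(), &[24]);           // merge everything
    Ok(())
}
```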
@@ -743,9 +764,9 @@ impl Tensor {
     /// Returns a tensor that is a transposed version of the input, the given dimensions are
     /// swapped.
-    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
-        self.check_dim(dim1, "transpose")?;
-        self.check_dim(dim2, "transpose")?;
+    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
+        let dim1 = dim1.to_index(self.shape(), "transpose")?;
+        let dim2 = dim2.to_index(self.shape(), "transpose")?;
         let op = if self.track_op() {
             Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
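
The same pattern applied to `transpose`, sketched with an assumed shape:

```rust
use candle::{DType, Device, Tensor, D};

fn demo() -> candle::Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // Swap the last two dims without spelling out rank() - 2 and rank() - 1.
    let u = t.transpose(D::Minus2, D::Minus1)?;
    assert_eq!(u.dims(), &[2, 4, 3]);
    Ok(())
}
```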
@@ -929,23 +950,23 @@ impl Tensor {
         }
     }
 
-    pub fn squeeze(&self, index: usize) -> Result<Self> {
+    pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
         // The PyTorch semantics are to return the same tensor if the target dimension
         // does not have a size of 1.
         let dims = self.dims();
-        self.check_dim(index, "squeeze")?;
-        if dims[index] == 1 {
+        let dim = dim.to_index(self.shape(), "squeeze")?;
+        if dims[dim] == 1 {
             let mut dims = dims.to_vec();
-            dims.remove(index);
+            dims.remove(dim);
             self.reshape(dims)
         } else {
             Ok(self.clone())
         }
     }
 
-    pub fn unsqueeze(&self, index: usize) -> Result<Self> {
+    pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
         let mut dims = self.dims().to_vec();
-        dims.insert(index, 1);
+        dims.insert(dim, 1);
         self.reshape(dims)
     }
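
And the PyTorch-style `squeeze` semantics mentioned in the comment, sketched on an assumed shape; note `unsqueeze` keeps its plain `usize` parameter:

```rust
use candle::{DType, Device, Tensor, D};

fn demo() -> candle::Result<()> {
    let t = Tensor::zeros((2, 1, 4), DType::F32, &Device::Cpu)?;
    assert_eq!(t.squeeze(1)?.dims(), &[2, 4]);            // size-1 dim removed
    assert_eq!(t.squeeze(D::Minus1)?.dims(), &[2, 1, 4]); // size 4, returned as-is
    assert_eq!(t.unsqueeze(0)?.dims(), &[1, 2, 1, 4]);
    Ok(())
}
```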

View File

@@ -386,12 +386,12 @@ impl BertSelfAttention {
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
-        let attention_probs = attention_scores.softmax(attention_scores.rank() - 1)?;
+        let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
         let attention_probs = self.dropout.forward(&attention_probs)?;
         let context_layer = attention_probs.matmul(&value_layer)?;
         let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
-        let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
+        let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
         Ok(context_layer)
     }
 }

View File

@@ -283,19 +283,18 @@ impl CausalSelfAttention {
         dims.push(v / 2);
         dims.push(2);
         let x = x.reshape(dims)?;
-        let rank = x.rank();
-        let re_x = x.narrow(rank - 1, 0, 1)?;
-        let im_x = x.narrow(rank - 1, 1, 1)?;
+        let re_x = x.narrow(candle::D::Minus1, 0, 1)?;
+        let im_x = x.narrow(candle::D::Minus1, 1, 1)?;
         let re_f = freqs_cis
-            .narrow(rank - 1, 0, 1)?
+            .narrow(candle::D::Minus1, 0, 1)?
             .broadcast_as(re_x.shape())?;
         let im_f = freqs_cis
-            .narrow(rank - 1, 1, 1)?
+            .narrow(candle::D::Minus1, 1, 1)?
             .broadcast_as(im_x.shape())?;
         let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
         let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
-        let rope = Tensor::cat(&[&re, &im], rank - 1)?;
-        let rope = rope.flatten(Some(rope.rank() - 2), None)?;
+        let rope = Tensor::cat(&[&re, &im], re.rank() - 1)?;
+        let rope = rope.flatten_from(candle::D::Minus2)?;
         Ok(rope)
     }
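
The trailing dimension of size 2 holds interleaved real/imaginary pairs, which is why end-relative indexing fits so naturally here. A shape-only sketch with made-up sizes:

```rust
use candle::{DType, Device, Tensor, D};

fn demo() -> candle::Result<()> {
    // Stand-in for an activation reshaped so each complex pair sits in a trailing dim.
    let x = Tensor::zeros((10, 8, 2), DType::F32, &Device::Cpu)?;
    let re = x.narrow(D::Minus1, 0, 1)?; // (10, 8, 1): real parts
    let im = x.narrow(D::Minus1, 1, 1)?; // (10, 8, 1): imaginary parts
    let rope = Tensor::cat(&[&re, &im], re.rank() - 1)?; // (10, 8, 2)
    assert_eq!(rope.flatten_from(D::Minus2)?.dims(), &[10, 16]);
    Ok(())
}
```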
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(att.rank() - 1)?;
+        let att = att.softmax(candle::D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
@@ -537,7 +536,7 @@ async fn main() -> Result<()> {
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
-            let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
+            let prs = (&logits / temperature)?.softmax(candle::D::Minus1)?;
             let logits_v: Vec<f32> = prs.to_vec1()?;
             let distr = rand::distributions::WeightedIndex::new(&logits_v)?;

View File

@@ -109,7 +109,7 @@ impl Decode {
         };
         tokens.push(next_token);
         let prob = logits
-            .softmax(logits.rank() - 1)?
+            .softmax(candle::D::Minus1)?
             .get(next_token as usize)?
             .to_scalar::<f32>()? as f64;
         if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {

View File

@@ -342,8 +342,8 @@ impl MultiHeadAttention {
             let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = qk.softmax(qk.rank() - 1)?;
-        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
+        let w = qk.softmax(candle::D::Minus1)?;
+        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
         Ok(wv)
     }
 }
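
Worth noting: this last call site keeps a plain positional index, `flatten_from(2)`. Because `usize` implements the new `Dim` trait, existing code that passes concrete indices compiles unchanged against the generic signatures; `D::Minus1` and `D::Minus2` are purely an opt-in shorthand for end-relative dimensions.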