Take as input slices of tensors as well as slices of &Tensors.

commit 334524e2c4 (parent 8b67f294e8)
Author: laurent
Date:   2023-06-25 17:07:09 +01:00

2 changed files with 26 additions and 17 deletions
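
The point of the change is easiest to see at the call site: Tensor::cat and
Tensor::cat0 are now generic over AsRef<Tensor>, so they accept a slice of
owned Tensors as well as a slice of &Tensors. A minimal sketch of both call
forms (not part of the diff; the candle crate path and the example values
are assumptions):

    use candle::{Device, Result, Tensor};

    fn cat_both_ways() -> Result<Tensor> {
        let a = Tensor::new(&[1f32, 2.], &Device::Cpu)?;
        let b = Tensor::new(&[3f32, 4.], &Device::Cpu)?;
        // A slice of references, as before.
        let _c = Tensor::cat(&[&a, &b], 0)?;
        // A slice of owned tensors now compiles too, via AsRef<Tensor>.
        Tensor::cat(&[a, b], 0)
    }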


@@ -427,7 +427,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?;
+        let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
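
The only change in the example is the input shape: the 1-D tensor of token
ids gains a leading batch dimension of one before being passed to forward.
A sketch of the shapes involved (the u32 token-id type is an assumption):

    use candle::{Device, Result, Tensor};

    fn batched_input(ctxt: &[u32]) -> Result<Tensor> {
        let input = Tensor::new(ctxt, &Device::Cpu)?; // shape: (ctxt.len(),)
        input.reshape((1, ctxt.len())) // shape: (1, ctxt.len()), a batch of one
    }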


@@ -24,6 +24,12 @@ pub struct Tensor_ {
     is_variable: bool,
 }
 
+impl AsRef<Tensor> for Tensor {
+    fn as_ref(&self) -> &Tensor {
+        self
+    }
+}
+
 // Tensors are refcounted so that cloning is cheap when building the op graph.
 // Storages are also refcounted independently so that its possible to avoid
 // copying the storage for operations that only modify the shape or stride.
@@ -802,19 +808,20 @@ impl Tensor {
         }
     }
 
-    pub fn cat(args: &[&Self], dim: usize) -> Result<Self> {
+    pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
+        let arg0 = args[0].as_ref();
         if args.len() == 1 {
-            return Ok(args[0].clone());
+            return Ok(arg0.clone());
         }
-        let rank = args[0].rank();
+        let rank = arg0.rank();
         if dim >= rank {
             return Err(Error::UnexpectedNumberOfDims {
                 expected: (dim + 1),
                 got: rank,
-                shape: args[0].shape().clone(),
+                shape: arg0.shape().clone(),
             });
         }
         if dim == 0 {
@@ -824,29 +831,30 @@ impl Tensor {
             // for dim != 0...
             let args: Vec<Tensor> = args
                 .iter()
-                .map(|a| a.transpose(0, dim))
+                .map(|a| a.as_ref().transpose(0, dim))
                 .collect::<Result<Vec<_>>>()?;
-            let args: Vec<&Tensor> = args.iter().collect();
             let cat = Self::cat0(&args)?;
             cat.transpose(0, dim)
         }
     }
 
-    pub fn cat0(args: &[&Self]) -> Result<Self> {
+    pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
+        let arg0 = args[0].as_ref();
         if args.len() == 1 {
-            return Ok(args[0].clone());
+            return Ok(arg0.clone());
         }
-        let rank = args[0].rank();
-        let device = args[0].device();
-        let dtype = args[0].dtype();
-        let first_dims = args[0].shape().dims();
+        let rank = arg0.rank();
+        let device = arg0.device();
+        let dtype = arg0.dtype();
+        let first_dims = arg0.shape().dims();
         let mut cat_dims = first_dims.to_vec();
         cat_dims[0] = 0;
         let mut offsets = vec![0usize];
         for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
             if arg.dtype() != dtype {
                 // TODO: Improve the error message.
                 return Err(Error::DTypeMismatchBinaryOp {
@@ -864,7 +872,7 @@ impl Tensor {
                 });
             }
             let mut mismatch = arg.rank() != rank;
-            for (dim_idx, (v1, v2)) in args[0]
+            for (dim_idx, (v1, v2)) in arg0
                 .shape()
                 .dims()
                 .iter()
@@ -883,7 +891,7 @@ impl Tensor {
             if mismatch {
                 return Err(Error::ShapeMismatchCat {
                     dim: 0, // TODO: not the appropriate error message
-                    first_shape: args[0].shape().clone(),
+                    first_shape: arg0.shape().clone(),
                     n: arg_idx + 1,
                     nth_shape: arg.shape().clone(),
                 });
@@ -892,14 +900,15 @@ impl Tensor {
             offsets.push(next_offset);
         }
         let shape = Shape::from(cat_dims);
-        let op = if args.iter().any(|arg| arg.track_op()) {
-            let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect();
+        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
+            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
             Some(Op::Cat(args, 0))
         } else {
             None
         };
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
+            let arg = arg.as_ref();
             arg.storage
                 .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?;
         }