diff --git a/examples/llama/main.rs b/examples/llama/main.rs
index 6355d10f..32ad71bc 100644
--- a/examples/llama/main.rs
+++ b/examples/llama/main.rs
@@ -427,7 +427,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?;
+        let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
diff --git a/src/tensor.rs b/src/tensor.rs
index 53a8b1f9..bfd9964f 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -24,6 +24,12 @@ pub struct Tensor_ {
     is_variable: bool,
 }
 
+impl AsRef<Tensor> for Tensor {
+    fn as_ref(&self) -> &Tensor {
+        self
+    }
+}
+
 // Tensors are refcounted so that cloning is cheap when building the op graph.
 // Storages are also refcounted independently so that its possible to avoid
 // copying the storage for operations that only modify the shape or stride.
@@ -802,19 +808,20 @@ impl Tensor {
         }
     }
 
-    pub fn cat(args: &[&Self], dim: usize) -> Result<Self> {
+    pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
+        let arg0 = args[0].as_ref();
         if args.len() == 1 {
-            return Ok(args[0].clone());
+            return Ok(arg0.clone());
         }
-        let rank = args[0].rank();
+        let rank = arg0.rank();
         if dim >= rank {
             return Err(Error::UnexpectedNumberOfDims {
                 expected: (dim + 1),
                 got: rank,
-                shape: args[0].shape().clone(),
+                shape: arg0.shape().clone(),
             });
         }
         if dim == 0 {
@@ -824,29 +831,30 @@ impl Tensor {
             // for dim != 0...
             let args: Vec<Tensor> = args
                 .iter()
-                .map(|a| a.transpose(0, dim))
+                .map(|a| a.as_ref().transpose(0, dim))
                 .collect::<Result<Vec<_>>>()?;
-            let args: Vec<&Tensor> = args.iter().collect();
             let cat = Self::cat0(&args)?;
             cat.transpose(0, dim)
         }
     }
 
-    pub fn cat0(args: &[&Self]) -> Result<Self> {
+    pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
+        let arg0 = args[0].as_ref();
         if args.len() == 1 {
-            return Ok(args[0].clone());
+            return Ok(arg0.clone());
         }
-        let rank = args[0].rank();
-        let device = args[0].device();
-        let dtype = args[0].dtype();
-        let first_dims = args[0].shape().dims();
+        let rank = arg0.rank();
+        let device = arg0.device();
+        let dtype = arg0.dtype();
+        let first_dims = arg0.shape().dims();
         let mut cat_dims = first_dims.to_vec();
         cat_dims[0] = 0;
         let mut offsets = vec![0usize];
         for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
             if arg.dtype() != dtype {
                 // TODO: Improve the error message.
                 return Err(Error::DTypeMismatchBinaryOp {
@@ -864,7 +872,7 @@ impl Tensor {
                 });
             }
             let mut mismatch = arg.rank() != rank;
-            for (dim_idx, (v1, v2)) in args[0]
+            for (dim_idx, (v1, v2)) in arg0
                 .shape()
                 .dims()
                 .iter()
@@ -883,7 +891,7 @@ impl Tensor {
             if mismatch {
                 return Err(Error::ShapeMismatchCat {
                     dim: 0, // TODO: not the appropriate error message
-                    first_shape: args[0].shape().clone(),
+                    first_shape: arg0.shape().clone(),
                     n: arg_idx + 1,
                     nth_shape: arg.shape().clone(),
                 });
@@ -892,14 +900,15 @@ impl Tensor {
             offsets.push(next_offset);
         }
         let shape = Shape::from(cat_dims);
-        let op = if args.iter().any(|arg| arg.track_op()) {
-            let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect();
+        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
+            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
             Some(Op::Cat(args, 0))
         } else {
             None
         };
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
+            let arg = arg.as_ref();
             arg.storage
                 .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?;
         }
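Not part of the patch, but a quick sketch of what the new bound buys at call sites. Assuming the Tensor::new / Device::Cpu API from the llama example above, cat can now take a slice of owned tensors directly (A = Tensor, via the new impl AsRef<Tensor> for Tensor), while the old &[&Tensor] call sites keep working (A = &Tensor, via std's blanket impl of AsRef<U> for &T where T: AsRef<U>):

    let a = Tensor::new(&[1f32, 2.], &Device::Cpu)?;
    let b = Tensor::new(&[3f32, 4.], &Device::Cpu)?;
    // Owned tensors: A = Tensor, using the new AsRef<Tensor> impl.
    let cat_owned = Tensor::cat(&[a.clone(), b.clone()], 0)?;
    // References: A = &Tensor, via the std blanket impl over &T.
    let cat_refs = Tensor::cat(&[&a, &b], 0)?;

This is also why the `let args: Vec<&Tensor> = args.iter().collect();` line in cat could be dropped: the Vec<Tensor> of transposed tensors can be passed to cat0 as-is.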