Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 11:56:45 +00:00)
Take as input slices of tensors as well as slices of &Tensors.
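This commit makes the concatenation entry points generic over `A: AsRef<Tensor>`, so they accept both `&[Tensor]` and `&[&Tensor]`, and updates the llama example to pass a batched input. A minimal usage sketch (the tensors `a` and `b` are assumed to already exist, on the same device and with the same dtype):

    // Both call forms compile after this change:
    let c1 = Tensor::cat(&[&a, &b], 0)?; // slice of &Tensor
    let c2 = Tensor::cat(&[a, b], 0)?;   // slice of owned Tensor (moves a and b)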
@@ -427,7 +427,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?;
+        let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
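For context on the example change: `Tensor::new(ctxt, &Device::Cpu)` builds a rank-1 tensor from the token slice, and the added `reshape((1, ctxt.len()))` turns it into a `(batch, seq_len)` input with a batch size of 1, the shape the model's forward pass expects here. Illustrative shapes (assumed, not shown in the diff):

    // ctxt.len() == 5
    // Tensor::new(ctxt, &Device::Cpu)?                           -> shape (5,)
    // Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))? -> shape (1, 5)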
@@ -24,6 +24,12 @@ pub struct Tensor_ {
     is_variable: bool,
 }
 
+impl AsRef<Tensor> for Tensor {
+    fn as_ref(&self) -> &Tensor {
+        self
+    }
+}
+
 // Tensors are refcounted so that cloning is cheap when building the op graph.
 // Storages are also refcounted independently so that its possible to avoid
 // copying the storage for operations that only modify the shape or stride.
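The blanket `impl AsRef<Tensor> for Tensor` is what makes the generic signatures below work for owned values, and the standard library forwards `AsRef` through references, so `&Tensor: AsRef<Tensor>` comes for free once this impl exists. A hypothetical helper (not part of the commit) showing the pattern:

    fn total_rank<A: AsRef<Tensor>>(args: &[A]) -> usize {
        args.iter().map(|a| a.as_ref().rank()).sum()
    }

    // Accepts both:
    //   total_rank(&[t0, t1])   with t0, t1: Tensor
    //   total_rank(&[&t0, &t1]) with &t0, &t1: &Tensor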
@@ -802,19 +808,20 @@ impl Tensor {
         }
     }
 
-    pub fn cat(args: &[&Self], dim: usize) -> Result<Self> {
+    pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
+        let arg0 = args[0].as_ref();
         if args.len() == 1 {
-            return Ok(args[0].clone());
+            return Ok(arg0.clone());
         }
-        let rank = args[0].rank();
+        let rank = arg0.rank();
         if dim >= rank {
             return Err(Error::UnexpectedNumberOfDims {
                 expected: (dim + 1),
                 got: rank,
-                shape: args[0].shape().clone(),
+                shape: arg0.shape().clone(),
             });
         }
         if dim == 0 {
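The new `arg0` binding is what makes the generic signature work: `args[0]` is now an `&A` rather than a `&Tensor`, so the single-tensor fast path and the shape/rank queries have to go through `as_ref()` first. A minimal reproduction of the pattern (hypothetical helper, not from the commit):

    fn first_tensor<A: AsRef<Tensor>>(args: &[A]) -> Tensor {
        args[0].as_ref().clone() // clones the Tensor, not the wrapper A
    }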
@@ -824,29 +831,30 @@ impl Tensor {
             // for dim != 0...
             let args: Vec<Tensor> = args
                 .iter()
-                .map(|a| a.transpose(0, dim))
+                .map(|a| a.as_ref().transpose(0, dim))
                 .collect::<Result<Vec<_>>>()?;
-            let args: Vec<&Tensor> = args.iter().collect();
             let cat = Self::cat0(&args)?;
             cat.transpose(0, dim)
         }
     }
 
-    pub fn cat0(args: &[&Self]) -> Result<Self> {
+    pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
+        let arg0 = args[0].as_ref();
         if args.len() == 1 {
-            return Ok(args[0].clone());
+            return Ok(arg0.clone());
         }
-        let rank = args[0].rank();
-        let device = args[0].device();
-        let dtype = args[0].dtype();
-        let first_dims = args[0].shape().dims();
+        let rank = arg0.rank();
+        let device = arg0.device();
+        let dtype = arg0.dtype();
+        let first_dims = arg0.shape().dims();
         let mut cat_dims = first_dims.to_vec();
         cat_dims[0] = 0;
         let mut offsets = vec![0usize];
         for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
             if arg.dtype() != dtype {
                 // TODO: Improve the error message.
                 return Err(Error::DTypeMismatchBinaryOp {
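For `dim != 0`, `cat` reuses `cat0` by moving the target dimension to the front, concatenating along dimension 0, and moving it back. Note that the intermediate `Vec<&Tensor>` collection is gone: since `cat0` is now generic over `AsRef<Tensor>`, the owned `Vec<Tensor>` of transposed tensors can be passed to it directly. Worked shapes (illustrative, not from the diff):

    // a: (2, 3), b: (2, 5), cat along dim 1 -> want (2, 8)
    // transpose(0, 1): a -> (3, 2), b -> (5, 2)
    // cat0 (dim 0)   : -> (8, 2)
    // transpose(0, 1): -> (2, 8)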
@@ -864,7 +872,7 @@ impl Tensor {
                 });
             }
             let mut mismatch = arg.rank() != rank;
-            for (dim_idx, (v1, v2)) in args[0]
+            for (dim_idx, (v1, v2)) in arg0
                 .shape()
                 .dims()
                 .iter()
@@ -883,7 +891,7 @@ impl Tensor {
             if mismatch {
                 return Err(Error::ShapeMismatchCat {
                     dim: 0, // TODO: not the appropriate error message
-                    first_shape: args[0].shape().clone(),
+                    first_shape: arg0.shape().clone(),
                     n: arg_idx + 1,
                     nth_shape: arg.shape().clone(),
                 });
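The mismatch check compares each argument's dimensions against `arg0`, rejecting any disagreement outside the concatenation dimension. For example (assumed shapes, not from the diff):

    // (2, 3) cat0 (4, 3) -> Ok, result shape (6, 3)
    // (2, 3) cat0 (2, 4) -> Err(ShapeMismatchCat): dim 1 differs (3 vs 4)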
@@ -892,14 +900,15 @@ impl Tensor {
             offsets.push(next_offset);
         }
         let shape = Shape::from(cat_dims);
-        let op = if args.iter().any(|arg| arg.track_op()) {
-            let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect();
+        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
+            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
             Some(Op::Cat(args, 0))
         } else {
             None
         };
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
+            let arg = arg.as_ref();
             arg.storage
                 .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?;
         }
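`offsets` records, in elements, where each argument's data begins in the destination storage; `copy_strided_src` then copies each source into place at that offset. A worked example with assumed shapes: concatenating `(2, 3)` and `(4, 3)` along dim 0 yields a `(6, 3)` result, and

    // offsets[0] = 0
    // offsets[1] = 0 + 2*3 = 6   (start of the second tensor's data)
    // offsets[2] = 6 + 4*3 = 18  (total element count; unused by the zip)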