Only narrow when needed + deactivate the kv cache.

This commit is contained in:
laurent
2023-06-29 19:07:52 +01:00
parent 3232df9458
commit b50bd880ce
3 changed files with 41 additions and 16 deletions

View File

@ -10,6 +10,14 @@ pub enum Error {
got: DType,
},
#[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
NarrowInvalidArgs {
shape: Shape,
dim: usize,
start: usize,
len: usize,
},
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },

View File

@ -349,21 +349,34 @@ impl Tensor {
}
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + length`.
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
let op = if self.track_op() {
Some(Op::Narrow(self.clone(), dim, start, length))
/// ranges from `start` to `start + len`.
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
if dim >= dims.len() || start + len > dims[dim] {
Err(Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
})?
}
if start == 0 && dims[dim] == len {
Ok(self.clone())
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout().narrow(dim, start, length)?,
op,
is_variable: false,
};
Ok(Tensor(Arc::new(tensor_)))
let op = if self.track_op() {
Some(Op::Narrow(self.clone(), dim, start, len))
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout().narrow(dim, start, len)?,
op,
is_variable: false,
};
Ok(Tensor(Arc::new(tensor_)))
}
}
pub fn softmax(&self, dim: usize) -> Result<Self> {