mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Only narrow when needed + deactivate the kv cache.
This commit is contained in:
@ -10,6 +10,14 @@ pub enum Error {
|
||||
got: DType,
|
||||
},
|
||||
|
||||
#[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
|
||||
NarrowInvalidArgs {
|
||||
shape: Shape,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
len: usize,
|
||||
},
|
||||
|
||||
#[error("{op} only supports contiguous tensors")]
|
||||
RequiresContiguous { op: &'static str },
|
||||
|
||||
|
@ -349,21 +349,34 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + length`.
|
||||
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Narrow(self.clone(), dim, start, length))
|
||||
/// ranges from `start` to `start + len`.
|
||||
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||
let dims = self.dims();
|
||||
if dim >= dims.len() || start + len > dims[dim] {
|
||||
Err(Error::NarrowInvalidArgs {
|
||||
shape: self.shape().clone(),
|
||||
dim,
|
||||
start,
|
||||
len,
|
||||
})?
|
||||
}
|
||||
if start == 0 && dims[dim] == len {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout().narrow(dim, start, length)?,
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Narrow(self.clone(), dim, start, len))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout().narrow(dim, start, len)?,
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
||||
|
Reference in New Issue
Block a user