mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Only narrow when needed + deactivate the kv cache.
This commit is contained in:
@ -24,7 +24,7 @@ mod var_store;
|
|||||||
mod weights;
|
mod weights;
|
||||||
|
|
||||||
const CONTEXT_SIZE: usize = 512;
|
const CONTEXT_SIZE: usize = 512;
|
||||||
const USE_KV_CACHE: bool = true;
|
const USE_KV_CACHE: bool = false;
|
||||||
const START_PROMPT: &str = r"
|
const START_PROMPT: &str = r"
|
||||||
EDWARD:
|
EDWARD:
|
||||||
I wonder how our princely father 'scaped,
|
I wonder how our princely father 'scaped,
|
||||||
@ -268,7 +268,11 @@ impl CausalSelfAttention {
|
|||||||
|
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
let mut dims = x.dims().to_vec();
|
let mut dims = x.dims().to_vec();
|
||||||
let freqs_cis = freqs_cis.narrow(1, freqs_cis.dims()[1] - dims[1], dims[1])?;
|
let freqs_cis = if dims[1] < CONTEXT_SIZE {
|
||||||
|
freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
|
||||||
|
} else {
|
||||||
|
freqs_cis.clone()
|
||||||
|
};
|
||||||
let v = dims.pop().unwrap();
|
let v = dims.pop().unwrap();
|
||||||
dims.push(v / 2);
|
dims.push(v / 2);
|
||||||
dims.push(2);
|
dims.push(2);
|
||||||
|
@ -10,6 +10,14 @@ pub enum Error {
|
|||||||
got: DType,
|
got: DType,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
#[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
|
||||||
|
NarrowInvalidArgs {
|
||||||
|
shape: Shape,
|
||||||
|
dim: usize,
|
||||||
|
start: usize,
|
||||||
|
len: usize,
|
||||||
|
},
|
||||||
|
|
||||||
#[error("{op} only supports contiguous tensors")]
|
#[error("{op} only supports contiguous tensors")]
|
||||||
RequiresContiguous { op: &'static str },
|
RequiresContiguous { op: &'static str },
|
||||||
|
|
||||||
|
@ -349,21 +349,34 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||||
/// ranges from `start` to `start + length`.
|
/// ranges from `start` to `start + len`.
|
||||||
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
|
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||||
let op = if self.track_op() {
|
let dims = self.dims();
|
||||||
Some(Op::Narrow(self.clone(), dim, start, length))
|
if dim >= dims.len() || start + len > dims[dim] {
|
||||||
|
Err(Error::NarrowInvalidArgs {
|
||||||
|
shape: self.shape().clone(),
|
||||||
|
dim,
|
||||||
|
start,
|
||||||
|
len,
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
if start == 0 && dims[dim] == len {
|
||||||
|
Ok(self.clone())
|
||||||
} else {
|
} else {
|
||||||
None
|
let op = if self.track_op() {
|
||||||
};
|
Some(Op::Narrow(self.clone(), dim, start, len))
|
||||||
let tensor_ = Tensor_ {
|
} else {
|
||||||
id: TensorId::new(),
|
None
|
||||||
storage: self.storage.clone(),
|
};
|
||||||
layout: self.layout().narrow(dim, start, length)?,
|
let tensor_ = Tensor_ {
|
||||||
op,
|
id: TensorId::new(),
|
||||||
is_variable: false,
|
storage: self.storage.clone(),
|
||||||
};
|
layout: self.layout().narrow(dim, start, len)?,
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
op,
|
||||||
|
is_variable: false,
|
||||||
|
};
|
||||||
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
||||||
|
Reference in New Issue
Block a user