Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 11:37:11 +00:00
Do not use the kv-cache on external key-value states. (#1054)
@@ -153,7 +153,6 @@ fn main() -> Result<()> {
     let args = Args::parse();

     let _guard = if args.tracing {
-        println!("tracing...");
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
         Some(guard)
@@ -348,21 +348,21 @@ impl T5Attention {
             .contiguous()?;
         let mut k = k
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
         let mut v = v
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;

-        if self.use_cache {
+        if self.use_cache && key_value_states.is_none() {
             let _enter = self.span_cache.enter();
             if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
-                k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
-                v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+                k = Tensor::cat(&[kv_cache_k, &k], 2)?;
+                v = Tensor::cat(&[kv_cache_v, &v], 2)?;
             };
             self.kv_cache = Some((k.clone(), v.clone()));
         };
+        let k = k.contiguous()?;
+        let v = v.contiguous()?;
         // TODO: Use flash_attn.
         let scores = {
             let _enter = self.span_mm.enter();
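In T5, cross-attention layers are called with key_value_states taken from the encoder output, so their keys and values are identical at every decoding step; only self-attention (key_value_states is None) produces one new key/value pair per step and should be appended to the kv-cache. The hunk above adds that guard and also defers the .contiguous() calls until after any cache concatenation, so they run once on the final tensors. Below is a minimal sketch of the resulting pattern, assuming the candle-core crate and a hypothetical Attn struct that keeps only the use_cache and kv_cache fields shown in the diff:

    use candle_core::{Result, Tensor};

    struct Attn {
        use_cache: bool,
        kv_cache: Option<(Tensor, Tensor)>,
    }

    impl Attn {
        // key_value_states is Some(encoder_output) for cross-attention and
        // None for self-attention; only the self-attention path grows the cache.
        fn cached_kv(
            &mut self,
            mut k: Tensor,
            mut v: Tensor,
            key_value_states: Option<&Tensor>,
        ) -> Result<(Tensor, Tensor)> {
            if self.use_cache && key_value_states.is_none() {
                if let Some((ck, cv)) = &self.kv_cache {
                    // Concatenate along the sequence dimension (dim 2 in the hunk above).
                    k = Tensor::cat(&[ck, &k], 2)?;
                    v = Tensor::cat(&[cv, &v], 2)?;
                }
                self.kv_cache = Some((k.clone(), v.clone()));
            }
            // Make both tensors contiguous once, after any concatenation.
            Ok((k.contiguous()?, v.contiguous()?))
        }
    }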