Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 18:28:24 +00:00.
Remove some unnecessary calls to contiguous. (#1968)
* Remove some unnecessary calls to contiguous.
* Slightly improved kv cache concatenation.
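Background, not part of the commit itself: a candle tensor is contiguous when its elements sit in one packed block of memory, and Tensor::contiguous() only copies when that is not already the case, so such calls can simply be dropped wherever the input is known to be contiguous. A minimal sketch of the distinction, assuming the candle-core crate under its usual candle_core name:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // A freshly created tensor is laid out contiguously, so .contiguous()
        // returns it without copying anything.
        let x = Tensor::zeros((2, 3, 4), DType::F32, &dev)?;
        assert!(x.is_contiguous());
        let y = x.contiguous()?;
        assert_eq!(y.dims(), x.dims());
        // After a transpose the layout is no longer packed, and .contiguous()
        // has to materialize a copy.
        let xt = x.transpose(0, 2)?;
        assert!(!xt.is_contiguous());
        assert!(xt.contiguous()?.is_contiguous());
        Ok(())
    }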
@@ -58,20 +58,18 @@ impl Tensor {
                 }
             }
         }
-        if dim == 0 {
+        let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
+        if all_contiguous {
+            Self::cat_contiguous(args, dim)
+        } else if dim == 0 {
             Self::cat0(args)
         } else {
-            let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
-            if all_contiguous {
-                Self::cat_contiguous(args, dim)
-            } else {
-                let args: Vec<Tensor> = args
-                    .iter()
-                    .map(|a| a.as_ref().transpose(0, dim))
-                    .collect::<Result<Vec<_>>>()?;
-                let cat = Self::cat0(&args)?;
-                cat.transpose(0, dim)
-            }
+            let args: Vec<Tensor> = args
+                .iter()
+                .map(|a| a.as_ref().transpose(0, dim))
+                .collect::<Result<Vec<_>>>()?;
+            let cat = Self::cat0(&args)?;
+            cat.transpose(0, dim)
         }
     }
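With the reordering above, the all-contiguous fast path is checked first, so cat_contiguous handles any dimension whenever every input is contiguous; the transpose-then-cat0 fallback is only reached for non-contiguous inputs concatenated along a non-zero dim. A small usage sketch with made-up shapes, assuming candle_core:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
        let b = Tensor::zeros((2, 5), DType::F32, &dev)?;
        // Both inputs are contiguous, so this takes the fast path even though
        // the concatenation dimension is not 0.
        let c = Tensor::cat(&[&a, &b], 1)?;
        assert_eq!(c.dims(), &[2, 8]);
        assert!(c.is_contiguous());
        Ok(())
    }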
@@ -157,6 +157,8 @@ impl LayerWeights {
         let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
         let cos = self.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.sin.narrow(0, index_pos, seq_len)?;
+        // The call to contiguous below is only necessary when processing the prompt.
+        // When the seq_len is 1 in the inference loop, this is a no-op.
         candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin)
     }
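The comment added above can be checked directly: in the decode loop seq_len is 1, and transposing a (b, 1, n_head, head_dim) tensor into (b, n_head, 1, head_dim) leaves it contiguous, so the .contiguous() passed to rope_i only copies while the prompt is being processed. A sketch with made-up sizes, assuming candle_core:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (n_head, head_dim) = (32, 64);
        // Prompt processing: seq_len > 1, the transpose breaks contiguity and
        // .contiguous() performs a real copy.
        let prompt = Tensor::zeros((1, 7, n_head, head_dim), DType::F32, &dev)?.transpose(1, 2)?;
        assert!(!prompt.is_contiguous());
        // Decode loop: seq_len == 1, the transposed tensor is still contiguous
        // and the same call is a no-op.
        let decode = Tensor::zeros((1, 1, n_head, head_dim), DType::F32, &dev)?.transpose(1, 2)?;
        assert!(decode.is_contiguous());
        Ok(())
    }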
@@ -180,7 +182,11 @@ impl LayerWeights {
             .transpose(1, 2)?;
         let v = v
             .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            // This call to contiguous ensures that the fast kernel can be called below. It's
+            // actually a no-op except when processing the initial prompt so has no significant
+            // impact on performance.
+            .contiguous()?;

         let q = self.apply_rotary_emb(&q, index_pos)?;
         let k = self.apply_rotary_emb(&k, index_pos)?;
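The same layout reasoning applies to the v tensor above: reshape keeps the packed layout, transpose(1, 2) breaks it whenever seq_len > 1, and the newly added .contiguous() restores it so the fast kernel can be used. A sketch with hypothetical sizes, again assuming candle_core:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b_sz, seq_len, n_kv_head, head_dim) = (1, 7, 4, 8);
        let v = Tensor::zeros((b_sz, seq_len, n_kv_head * head_dim), DType::F32, &dev)?;
        // reshape preserves contiguity, transpose(1, 2) does not, and
        // .contiguous() packs the data back into a single block.
        let v = v
            .reshape((b_sz, seq_len, n_kv_head, head_dim))?
            .transpose(1, 2)?;
        assert!(!v.is_contiguous());
        let v = v.contiguous()?;
        assert_eq!(v.dims(), &[b_sz, n_kv_head, seq_len, head_dim]);
        Ok(())
    }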
@@ -191,8 +197,8 @@ impl LayerWeights {
                 if index_pos == 0 {
                     (k, v)
                 } else {
-                    let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
-                    let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+                    let k = Tensor::cat(&[k_cache, &k], 2)?;
+                    let v = Tensor::cat(&[v_cache, &v], 2)?;
                     (k, v)
                 }
             }
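The kv cache change drops the .contiguous() after each concatenation: when the cache and the new slice are both contiguous, Tensor::cat along dim 2 already produces a contiguous result, so the extra call was redundant. An illustrative sketch with made-up shapes, assuming candle_core:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, n_kv_head, head_dim) = (1, 4, 8);
        let k_cache = Tensor::zeros((b, n_kv_head, 5, head_dim), DType::F32, &dev)?;
        let k_new = Tensor::zeros((b, n_kv_head, 1, head_dim), DType::F32, &dev)?;
        // Concatenating two contiguous tensors along dim 2 yields a contiguous
        // tensor, so no follow-up .contiguous() is needed.
        let k = Tensor::cat(&[&k_cache, &k_new], 2)?;
        assert!(k.is_contiguous());
        assert_eq!(k.dims(), &[b, n_kv_head, 6, head_dim]);
        Ok(())
    }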
@@ -486,7 +492,7 @@ impl ModelWeights {
             layer_in = x
         }
         let x = self.norm.forward(&layer_in)?;
-        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
+        let x = x.i((.., seq_len - 1, ..))?;
         let _enter = self.span_output.enter();
         self.output.forward(&x)
     }
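The last hunk removes the .contiguous() after indexing out the final position; the indexing itself just selects the last token's hidden state before the output projection. A sketch with hypothetical sizes showing the shape effect of .i((.., seq_len - 1, ..)), assuming candle_core:

    use candle_core::{DType, Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b_sz, seq_len, hidden) = (2, 5, 16);
        let x = Tensor::zeros((b_sz, seq_len, hidden), DType::F32, &dev)?;
        // Indexing with a scalar drops that dimension:
        // (b_sz, seq_len, hidden) -> (b_sz, hidden).
        let last = x.i((.., seq_len - 1, ..))?;
        assert_eq!(last.dims(), &[b_sz, hidden]);
        Ok(())
    }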