mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add the stable-lm example. (#1046)
* Add the stable-lm example. * Get stable-lm to generate some proper text.
This commit is contained in:
@ -148,6 +148,7 @@ struct Attention {
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
use_cache: bool,
|
||||
rotary_ndims: usize,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
@ -173,6 +174,7 @@ impl Attention {
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
use_cache: cfg.use_cache,
|
||||
rotary_ndims: cfg.rotary_ndims(),
|
||||
})
|
||||
}
|
||||
|
||||
@ -210,9 +212,16 @@ impl Attention {
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims);
|
||||
let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?;
|
||||
let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
|
||||
let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?;
|
||||
let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
|
||||
let (query_rot, key_rot) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
.apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?;
|
||||
let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?;
|
||||
let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
@ -226,8 +235,8 @@ impl Attention {
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
}
|
||||
|
||||
let key_states = self.repeat_kv(key_states)?;
|
||||
let value_states = self.repeat_kv(value_states)?;
|
||||
let key_states = self.repeat_kv(key_states)?.contiguous()?;
|
||||
let value_states = self.repeat_kv(value_states)?.contiguous()?;
|
||||
|
||||
let attn_output = {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
|
Reference in New Issue
Block a user