Add the stable-lm example. (#1046)

* Add the stable-lm example.

* Get stable-lm to generate some proper text.
This commit is contained in:
Laurent Mazare
2023-10-06 19:20:35 +01:00
committed by GitHub
parent 904bbdae65
commit d5f7267087
2 changed files with 263 additions and 4 deletions

View File

@ -148,6 +148,7 @@ struct Attention {
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
use_cache: bool,
rotary_ndims: usize,
}
impl Attention {
@ -173,6 +174,7 @@ impl Attention {
rotary_emb,
kv_cache: None,
use_cache: cfg.use_cache,
rotary_ndims: cfg.rotary_ndims(),
})
}
@ -210,9 +212,16 @@ impl Attention {
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims);
let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?;
let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?;
let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
let (query_rot, key_rot) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
.apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?;
let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?;
let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
@ -226,8 +235,8 @@ impl Attention {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);