Move the weight bits into a separate module. (#295)

Author: Laurent Mazare
Date: 2023-08-01 10:37:06 +01:00
Committed by: GitHub
Parent: 614f911e9e
Commit: 75e0448114
3 changed files with 168 additions and 164 deletions

@@ -106,7 +106,6 @@ struct CausalSelfAttention {
     n_key_value_head: usize,
     head_dim: usize,
     cache: Cache,
-    max_seq_len: usize,
 }
 
 impl CausalSelfAttention {
@@ -198,7 +197,6 @@ impl CausalSelfAttention {
             n_key_value_head: cfg.n_kv_heads,
             head_dim: cfg.dim / cfg.n_heads,
             cache: cache.clone(),
-            max_seq_len: cfg.seq_len,
         })
     }
 }
@@ -291,7 +289,7 @@ pub struct Llama {
 impl Llama {
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, seq_len) = x.dims2()?;
+        let (_b_sz, _seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, index_pos, block_idx)?;
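
The hunks above only show the model-side cleanup; per the commit title, the weight-loading bits now live in a module of their own. Below is a rough, generic Rust sketch of that pattern, not the actual candle code: the `weights` module, `TransformerWeights`, `from_reader`, and the checkpoint layout are illustrative assumptions, shown only to make the shape of the refactor concrete.

mod weights {
    use std::fs::File;
    use std::io::{self, BufReader, Read};

    /// Raw weight buffers read from a checkpoint file (illustrative only).
    pub struct TransformerWeights {
        pub token_embedding: Vec<f32>,
    }

    impl TransformerWeights {
        /// Read `n` little-endian f32 values from the reader.
        fn read_f32s<R: Read>(reader: &mut R, n: usize) -> io::Result<Vec<f32>> {
            let mut buf = vec![0u8; n * 4];
            reader.read_exact(&mut buf)?;
            Ok(buf
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect())
        }

        /// Load a `vocab_size * dim` embedding table from `path`.
        pub fn from_reader(path: &str, vocab_size: usize, dim: usize) -> io::Result<Self> {
            let mut reader = BufReader::new(File::open(path)?);
            let token_embedding = Self::read_f32s(&mut reader, vocab_size * dim)?;
            Ok(Self { token_embedding })
        }
    }
}

fn main() -> std::io::Result<()> {
    // The caller (e.g. the model-building code) only sees the module's API.
    let w = weights::TransformerWeights::from_reader("model.bin", 32000, 288)?;
    println!("loaded {} embedding floats", w.token_embedding.len());
    Ok(())
}

Keeping the byte-level file reading behind a small module boundary like this leaves the model code free to deal only with the loaded buffers.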