mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Move the weight bits into a separate module. (#295)
This commit is contained in:
@@ -106,7 +106,6 @@ struct CausalSelfAttention {
|
||||
n_key_value_head: usize,
|
||||
head_dim: usize,
|
||||
cache: Cache,
|
||||
max_seq_len: usize,
|
||||
}
|
||||
|
||||
impl CausalSelfAttention {
|
||||
@@ -198,7 +197,6 @@ impl CausalSelfAttention {
|
||||
n_key_value_head: cfg.n_kv_heads,
|
||||
head_dim: cfg.dim / cfg.n_heads,
|
||||
cache: cache.clone(),
|
||||
max_seq_len: cfg.seq_len,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -291,7 +289,7 @@ pub struct Llama {
|
||||
|
||||
impl Llama {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
- let (_b_sz, seq_len) = x.dims2()?;
|
||||
+ let (_b_sz, _seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx)?;
|
||||
|
Reference in New Issue
Block a user