mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Tensor mutability (#154)
* Working towards tensor mutability. * Use a ref-cell to provide tensor mutability.
This commit is contained in:
@ -183,7 +183,7 @@ impl FalconRotaryEmbedding {
|
||||
past_kv_len: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_batch, seq_len, _head_dim) = query.shape().r3()?;
|
||||
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, &query.device(), query.dtype())?;
|
||||
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
|
||||
let cos = cos.narrow(0, past_kv_len, seq_len)?;
|
||||
let sin = sin.narrow(0, past_kv_len, seq_len)?;
|
||||
let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
|
||||
@ -194,7 +194,7 @@ impl FalconRotaryEmbedding {
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
@ -471,7 +471,7 @@ impl Falcon {
|
||||
Some((k, _)) => k.dim(1)?,
|
||||
None => 0,
|
||||
};
|
||||
let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
|
||||
let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
|
||||
}
|
||||
|
Reference in New Issue
Block a user