mirror of https://github.com/huggingface/candle.git
Tensor mutability (#154)
* Working towards tensor mutability.
* Use a ref-cell to provide tensor mutability.
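A note on the mechanism, since the commit message is terse: wrapping a tensor's storage in a ref-cell provides interior mutability, so in-place updates can go through a shared handle while plain metadata such as the device can be handed out by reference. The sketch below is a minimal illustration of that pattern, not candle's actual implementation; the Tensor, storage, and set_all names are invented for the example. The last point is what drives every hunk in this diff: once device() returns a reference, the extra & at the call sites becomes redundant.

use std::cell::RefCell;
use std::rc::Rc;

#[derive(Clone, Copy, Debug, PartialEq)]
enum Device {
    Cpu,
}

// Minimal sketch: storage lives behind Rc<RefCell<...>> so clones share it
// and can mutate it through &self, while plain fields like `device` can be
// returned by reference.
#[derive(Clone)]
struct Tensor {
    storage: Rc<RefCell<Vec<f32>>>,
    device: Device,
}

impl Tensor {
    fn new(data: Vec<f32>, device: Device) -> Self {
        Self { storage: Rc::new(RefCell::new(data)), device }
    }

    // Returning &Device rather than Device means callers write
    // `t.device()` where they previously needed `&t.device()`.
    fn device(&self) -> &Device {
        &self.device
    }

    // In-place mutation through a shared (non-mut) handle.
    fn set_all(&self, value: f32) {
        for v in self.storage.borrow_mut().iter_mut() {
            *v = value;
        }
    }
}

fn main() {
    let t = Tensor::new(vec![1.0, 2.0, 3.0], Device::Cpu);
    let alias = t.clone(); // shares the same storage
    t.set_all(0.0);
    assert_eq!(*alias.storage.borrow(), vec![0.0, 0.0, 0.0]);
    assert_eq!(alias.device(), &Device::Cpu);
}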
@@ -196,7 +196,7 @@ impl BertEmbeddings {
         if let Some(position_embeddings) = &self.position_embeddings {
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
-            let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
+            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
             embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
         }
         let embeddings = self.layer_norm.forward(&embeddings)?;
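The hunk above is representative: every change in this commit deletes a borrow in front of .device(). That is consistent with the accessor changing from returning an owned Device to returning &Device; the traits below are hypothetical stand-ins that only illustrate the assumed before/after signatures, which are inferred from the call sites rather than shown in this diff.

enum Device {
    Cpu,
}

// Hypothetical before/after signatures for the accessor; the real method
// lives on candle's Tensor and is not part of this diff.
trait DeviceAccessorBefore {
    fn device(&self) -> Device; // owned: callers borrow it, `&t.device()`
}

trait DeviceAccessorAfter {
    fn device(&self) -> &Device; // reference: callers pass `t.device()` directly
}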
@@ -183,7 +183,7 @@ impl FalconRotaryEmbedding {
         past_kv_len: usize,
     ) -> Result<(Tensor, Tensor)> {
         let (_batch, seq_len, _head_dim) = query.shape().r3()?;
-        let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, &query.device(), query.dtype())?;
+        let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
         let cos = cos.narrow(0, past_kv_len, seq_len)?;
         let sin = sin.narrow(0, past_kv_len, seq_len)?;
         let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
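The rotary update above multiplies the query by cos/sin tables narrowed to the current positions and adds the rotated half. rotate_half itself is not in this diff; the sketch below is the standard definition of that helper written with candle-style ops (narrow, neg, cat), assuming candle_core as the crate name, and may differ in detail from the one in the repo.

use candle_core::{D, Result, Tensor};

// Standard rotary-embedding helper: split the last dimension in two and
// rotate, mapping [x1, x2] to [-x2, x1].
fn rotate_half(x: &Tensor) -> Result<Tensor> {
    let last_dim = x.dim(D::Minus1)?;
    let x1 = x.narrow(D::Minus1, 0, last_dim / 2)?;
    let x2 = x.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)
}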
@@ -194,7 +194,7 @@ impl FalconRotaryEmbedding {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
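masked_fill broadcasts the scalar on_true to the mask's shape and lets where_cond pick it wherever the mask is non-zero, keeping on_false elsewhere. A hedged usage sketch (assuming candle_core as the crate name and a u8 mask, both of which may differ across versions): filling blocked attention positions with negative infinity before a softmax.

use candle_core::{Device, Result, Tensor};

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    mask.where_cond(&on_true, on_false)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let scores = Tensor::new(&[[0.5f32, 1.0, 2.0]], &device)?;
    let mask = Tensor::new(&[[0u8, 0, 1]], &device)?; // 1 marks a blocked position
    let masked = masked_fill(&scores, &mask, f32::NEG_INFINITY)?;
    println!("{masked}"); // [[0.5, 1.0, -inf]]
    Ok(())
}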
@@ -471,7 +471,7 @@ impl Falcon {
             Some((k, _)) => k.dim(1)?,
             None => 0,
         };
-        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
+        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
         for block in self.blocks.iter_mut() {
             hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
         }
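prepare_attn_mask is not shown in this diff. For context, a causal mask marks the future positions each query must not attend to; the function below is a hypothetical stand-in written for illustration (name, shape, and dtype are assumptions, not the repo's definition).

use candle_core::{Device, Result, Tensor};

// Hypothetical causal mask: mask[i][j] = 1 where j > i (a future position),
// 0 elsewhere. The repo's prepare_attn_mask may differ in shape and dtype.
fn causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let values: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_vec(values, (seq_len, seq_len), device)
}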
@@ -227,7 +227,7 @@ impl CausalSelfAttention {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
@@ -180,7 +180,7 @@ impl EncodecResidualVectorQuantizer {
     }
 
     fn decode(&self, codes: &Tensor) -> Result<Tensor> {
-        let mut quantized_out = Tensor::zeros((), DType::F32, &codes.device())?;
+        let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
         if codes.dim(0)? != self.layers.len() {
             anyhow::bail!(
                 "codes shape {:?} does not match the number of quantization layers {}",
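decode seeds its accumulator with Tensor::zeros((), ...), a rank-0 tensor, so the first broadcast add against a layer's output takes on that output's shape without knowing it up front. A small sketch of the pattern (assuming candle_core and its broadcasting semantics):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Rank-0 zero: broadcasts against any shape on the first add.
    let mut acc = Tensor::zeros((), DType::F32, &device)?;
    for step in 0..3u32 {
        let layer_out = Tensor::full(step as f32, (2, 4), &device)?;
        acc = acc.broadcast_add(&layer_out)?;
    }
    println!("{acc}"); // every element is 0.0 + 1.0 + 2.0 = 3.0
    Ok(())
}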
@@ -311,13 +311,13 @@ impl MusicgenDecoder {
         let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
         let b_sz = b_sz_times_codebooks / self.num_codebooks;
         let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
-        let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, &dev)?;
+        let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
         for (idx, codebook) in self.embed_tokens.iter().enumerate() {
             let inp = input.narrow(1, idx, 1)?.squeeze(1)?;
             inputs_embeds = (inputs_embeds + codebook.forward(&inp)?)?
         }
         let inputs_embeds = inputs_embeds;
-        let positions = self.embed_positions.forward(&input)?.to_device(&dev)?;
+        let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
         let mut xs = inputs_embeds.broadcast_add(&positions)?;
         let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
         for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
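The decoder receives input_ids with batch and codebook flattened together, reshapes them apart, then slices one codebook at a time with narrow(1, idx, 1) followed by squeeze(1) and sums the per-codebook embeddings into inputs_embeds. A shape-only sketch of that slicing (assuming candle_core; the embedding lookups are omitted):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (b_sz, num_codebooks, seq_len) = (2, 4, 5);
    // Stand-in for input_ids: (b_sz * num_codebooks, seq_len).
    let input_ids = Tensor::zeros((b_sz * num_codebooks, seq_len), DType::U32, &device)?;
    let input = input_ids.reshape((b_sz, num_codebooks, seq_len))?;
    for idx in 0..num_codebooks {
        // One codebook's token ids: (b_sz, seq_len).
        let inp = input.narrow(1, idx, 1)?.squeeze(1)?;
        assert_eq!(inp.dims(), &[b_sz, seq_len]);
    }
    Ok(())
}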
@@ -109,7 +109,7 @@ impl Decoder {
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![SOT_TOKEN];
         for i in 0..sample_len {
-            let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
+            let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
 
             // The model expects a batch dim but this inference loop does not handle
             // it so we add it at this point.
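The trailing comment refers to a line just past the end of this hunk, so the actual code is not shown here. In candle, adding a leading batch dimension is typically unsqueeze(0); the helper below is a hedged sketch of that step, not the elided line itself.

use candle_core::{Result, Tensor};

// Hedged sketch: turn a (seq_len,) token tensor into (1, seq_len) so the
// model sees a batch dimension.
fn add_batch_dim(tokens_t: &Tensor) -> Result<Tensor> {
    tokens_t.unsqueeze(0)
}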