Tensor mutability (#154)

* Working towards tensor mutability.

* Use a ref-cell to provide tensor mutability.
This commit is contained in:
Laurent Mazare
2023-07-13 11:04:40 +01:00
committed by GitHub
parent a3663ce2f2
commit 50b0946a2d
14 changed files with 124 additions and 88 deletions

View File

@ -311,13 +311,13 @@ impl MusicgenDecoder {
let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
let b_sz = b_sz_times_codebooks / self.num_codebooks;
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, &dev)?;
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
for (idx, codebook) in self.embed_tokens.iter().enumerate() {
let inp = input.narrow(1, idx, 1)?.squeeze(1)?;
inputs_embeds = (inputs_embeds + codebook.forward(&inp)?)?
}
let inputs_embeds = inputs_embeds;
let positions = self.embed_positions.forward(&input)?.to_device(&dev)?;
let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
let mut xs = inputs_embeds.broadcast_add(&positions)?;
let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {