Tensor mutability (#154)

* Working towards tensor mutability.

* Use a ref-cell to provide tensor mutability.
This commit is contained in:
Laurent Mazare
2023-07-13 11:04:40 +01:00
committed by GitHub
parent a3663ce2f2
commit 50b0946a2d
14 changed files with 124 additions and 88 deletions

View File

@ -196,7 +196,7 @@ impl BertEmbeddings {
if let Some(position_embeddings) = &self.position_embeddings {
// TODO: Proper absolute positions?
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
}
let embeddings = self.layer_norm.forward(&embeddings)?;