Use index-select for the embeddings as it supports backprop. (#298)

This commit is contained in:
Laurent Mazare
2023-08-01 20:44:43 +01:00
committed by GitHub
parent ff876c2103
commit cc76c63202
2 changed files with 3 additions and 1 deletions

View File

@ -23,7 +23,7 @@ impl Embedding {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;
let values = Tensor::embedding(&indexes, &self.embeddings)?;
let values = self.embeddings.index_select(&indexes, 0)?;
let values = values.reshape(final_dims)?;
Ok(values)
}