diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 38336ecf..a8702df7 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -536,6 +536,8 @@ fn embeddings(device: &Device) -> Result<()> { let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?; let hs = Tensor::embedding(&ids, &t)?; assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + let hs = t.index_select(&ids, 0)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); Ok(()) } diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index 050123be..f4ba88e7 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -23,7 +23,7 @@ impl Embedding { let mut final_dims = indexes.dims().to_vec(); final_dims.push(self.hidden_size); let indexes = indexes.flatten_all()?; - let values = Tensor::embedding(&indexes, &self.embeddings)?; + let values = self.embeddings.index_select(&indexes, 0)?; let values = values.reshape(final_dims)?; Ok(values) }