mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select. * Also remove the cuda kernels.
This commit is contained in:
@ -534,7 +534,7 @@ fn cat(device: &Device) -> Result<()> {
|
||||
fn embeddings(device: &Device) -> Result<()> {
|
||||
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
|
||||
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
|
||||
let hs = Tensor::embedding(&ids, &t)?;
|
||||
let hs = t.embedding(&ids)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
|
Reference in New Issue
Block a user