mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Adapt the cuda bits.
This commit is contained in:
@ -481,10 +481,10 @@ impl Tensor {
|
||||
}
|
||||
let ids_shape = ids.shape();
|
||||
let seq_len = ids_shape.r1()?;
|
||||
let (vocab_size, hidden_size) = rhs.shape().r2()?;
|
||||
let (_, hidden_size) = rhs.shape().r2()?;
|
||||
let storage = ids
|
||||
.storage
|
||||
.embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?;
|
||||
.embedding(ids.layout(), &rhs.storage, rhs.layout())?;
|
||||
let shape: Shape = (seq_len, hidden_size).into();
|
||||
let op = if ids.track_op() || rhs.track_op() {
|
||||
Some(Op::Embedding(ids.clone(), rhs.clone()))
|
||||
|
Reference in New Issue
Block a user