Adapt the cuda bits.

This commit is contained in:
laurent
2023-06-28 15:43:03 +01:00
parent cca699be6c
commit 3f0d9fbb25
5 changed files with 109 additions and 87 deletions

View File

@ -481,10 +481,10 @@ impl Tensor {
}
let ids_shape = ids.shape();
let seq_len = ids_shape.r1()?;
let (vocab_size, hidden_size) = rhs.shape().r2()?;
let (_, hidden_size) = rhs.shape().r2()?;
let storage = ids
.storage
.embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?;
.embedding(ids.layout(), &rhs.storage, rhs.layout())?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))