Address comments.

Nicolas Patry
2023-06-23 13:42:45 +02:00
parent 96289bce08
commit 2fb87edda5
2 changed files with 9 additions and 4 deletions

@@ -128,8 +128,8 @@ impl Storage {
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
-        self.same_device(rhs, "matmul")?;
-        self.same_dtype(rhs, "matmul")?;
+        self.same_device(rhs, "embedding")?;
+        self.same_dtype(rhs, "embedding")?;
         match (self, rhs) {
            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
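
The only change in this hunk is the error-context string: the device and dtype checks on the embedding path were copy-pasted from matmul and still reported "matmul" on a mismatch. A self-contained sketch of why that string matters, using stand-in types rather than the crate's own Device/Error definitions:

// Illustrative sketch: the op name passed into the check ends up in the
// mismatch error, so it has to name the calling operation.
#[derive(Debug, PartialEq)]
enum Dev {
    Cpu,
    Cuda(usize),
}

fn same_device(lhs: &Dev, rhs: &Dev, op: &'static str) -> Result<(), String> {
    if lhs == rhs {
        Ok(())
    } else {
        Err(format!("{op}: device mismatch ({lhs:?} vs {rhs:?})"))
    }
}

fn main() {
    let err = same_device(&Dev::Cpu, &Dev::Cuda(0), "embedding").unwrap_err();
    assert!(err.contains("embedding")); // the old copy-pasted string said "matmul"
    println!("{err}");
}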

@@ -360,12 +360,17 @@ impl Tensor {
             .storage
             .embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
         let shape: Shape = (seq_len, hidden_size).into();
+        let op = if ids.track_op() || rhs.track_op() {
+            Some(Op::Embedding(ids.clone(), rhs.clone()))
+        } else {
+            None
+        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage,
             shape: shape.clone(),
             stride: shape.stride_contiguous(),
-            op: Some(Op::Embedding(ids.clone(), rhs.clone())),
+            op,
             is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
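
This hunk gates op recording on track_op(): the embedding result only keeps clones of its inputs when at least one of them participates in gradient tracking. A self-contained sketch of the same pattern, with stand-in types rather than the actual Tensor/Op/Tensor_ definitions:

// Sketch of the gating pattern: the backward op is recorded only when an
// input is tracked, so constant-only computations store `None` and keep no
// extra references alive for autograd.
use std::sync::Arc;

#[derive(Clone)]
struct T {
    tracked: bool,
    op: Option<Arc<Op>>,
}

enum Op {
    Embedding(T, T),
}

impl T {
    fn track_op(&self) -> bool {
        self.tracked
    }
}

fn embedding(ids: &T, table: &T) -> T {
    // Mirrors the hunk above: only build the graph node when needed.
    let op = if ids.track_op() || table.track_op() {
        Some(Arc::new(Op::Embedding(ids.clone(), table.clone())))
    } else {
        None
    };
    T { tracked: op.is_some(), op }
}

fn main() {
    let constant = T { tracked: false, op: None };
    let weights = T { tracked: true, op: None };
    assert!(embedding(&constant, &constant).op.is_none()); // nothing to backprop
    assert!(embedding(&constant, &weights).op.is_some()); // graph node recorded
}
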
@@ -859,7 +864,7 @@ impl Tensor {
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
                 Op::Embedding(_lhs, _rhs) => {
-                    todo!("Backward for embedding not implemented");
+                    return Err(Error::BackwardNotSupported { op: "embedding" })
                 }
                 Op::Matmul(lhs, rhs) => {
                     // Skipping checks, the op went ok, we can skip
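
The backward pass for Op::Embedding is still unimplemented, but it now surfaces as a recoverable Error::BackwardNotSupported instead of a todo!() panic. A minimal sketch of that error path; the enum below only mirrors the variant used in this hunk and is not the crate's full Error type:

// Sketch: callers can match on the failure instead of aborting the process.
#[derive(Debug)]
enum Error {
    BackwardNotSupported { op: &'static str },
}

fn backward_embedding() -> Result<(), Error> {
    // Same shape as the replacement line above.
    Err(Error::BackwardNotSupported { op: "embedding" })
}

fn main() {
    match backward_embedding() {
        Err(Error::BackwardNotSupported { op }) => {
            eprintln!("backward not supported for `{op}`, skipping"); // recoverable
        }
        Ok(()) => {}
    }
}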