From 2fb87edda5a9fec6d31aefc445b242b1748f9bd2 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 23 Jun 2023 13:42:45 +0200
Subject: [PATCH] Address comments.

---
 src/storage.rs | 4 ++--
 src/tensor.rs  | 9 +++++++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/storage.rs b/src/storage.rs
index c4938fa3..74ab7f40 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -128,8 +128,8 @@ impl Storage {
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
-        self.same_device(rhs, "matmul")?;
-        self.same_dtype(rhs, "matmul")?;
+        self.same_device(rhs, "embedding")?;
+        self.same_dtype(rhs, "embedding")?;
         match (self, rhs) {
             (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
                 let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
diff --git a/src/tensor.rs b/src/tensor.rs
index 65a90d7e..c38e51cd 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -360,12 +360,17 @@ impl Tensor {
             .storage
             .embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
         let shape: Shape = (seq_len, hidden_size).into();
+        let op = if ids.track_op() || rhs.track_op() {
+            Some(Op::Embedding(ids.clone(), rhs.clone()))
+        } else {
+            None
+        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage,
             shape: shape.clone(),
             stride: shape.stride_contiguous(),
-            op: Some(Op::Embedding(ids.clone(), rhs.clone())),
+            op,
             is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
@@ -859,7 +864,7 @@ impl Tensor {
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
                 Op::Embedding(_lhs, _rhs) => {
-                    todo!("Backward for embedding not implemented");
+                    return Err(Error::BackwardNotSupported { op: "embedding" })
                 }
                 Op::Matmul(lhs, rhs) => {
                     // Skipping checks, the op went ok, we can skip
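
Note for reviewers: the sketch below shows, in isolation, the two behaviours this patch introduces: the embedding op is only recorded when `ids.track_op() || rhs.track_op()` holds, and an unsupported backward now surfaces as `Error::BackwardNotSupported` instead of hitting a `todo!`. It is a minimal stand-alone sketch, not the crate's real API; the `Node` type, its fields, and this particular `track_op` definition are simplified assumptions made only to illustrate the control flow.

```rust
use std::sync::Arc;

// Simplified stand-ins for the real types; names and fields are illustrative only.
#[derive(Debug)]
enum Op {
    Embedding(Node, Node),
}

#[derive(Debug)]
enum Error {
    BackwardNotSupported { op: &'static str },
}

#[derive(Debug, Clone)]
struct Node {
    // In the real code this would be a full tensor; kept minimal here.
    is_variable: bool,
    op: Option<Arc<Op>>,
}

impl Node {
    /// An op is worth recording if this node is a variable or was itself
    /// produced by a recorded op, i.e. it participates in the autograd graph.
    fn track_op(&self) -> bool {
        self.is_variable || self.op.is_some()
    }

    fn embedding(ids: &Node, rhs: &Node) -> Node {
        // Only build the backward-graph entry when at least one input needs it,
        // mirroring the `if ids.track_op() || rhs.track_op()` added in the patch.
        let op = if ids.track_op() || rhs.track_op() {
            Some(Arc::new(Op::Embedding(ids.clone(), rhs.clone())))
        } else {
            None
        };
        Node { is_variable: false, op }
    }

    /// Backward pass: unsupported ops return an error instead of panicking.
    fn backward(&self) -> Result<(), Error> {
        match self.op.as_deref() {
            Some(Op::Embedding(_ids, _rhs)) => {
                Err(Error::BackwardNotSupported { op: "embedding" })
            }
            None => Ok(()), // constant node, nothing to differentiate
        }
    }
}

fn main() -> Result<(), Error> {
    let ids = Node { is_variable: false, op: None };
    let weights = Node { is_variable: true, op: None };

    // `weights` is a variable, so the op is recorded and backward reports
    // that embedding gradients are not implemented yet.
    let out = Node::embedding(&ids, &weights);
    assert!(out.op.is_some());
    assert!(matches!(
        out.backward(),
        Err(Error::BackwardNotSupported { op: "embedding" })
    ));

    // With two constant inputs nothing is recorded and backward is a no-op.
    let constant_out = Node::embedding(&ids, &Node { is_variable: false, op: None });
    assert!(constant_out.op.is_none());
    constant_out.backward()
}
```

As far as the diff shows, the intent is that purely constant computations skip allocating `Op` nodes entirely, while calling backward through an op whose gradient is not yet implemented fails with a recoverable `Result` rather than a panic.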