Merge pull request #6 from LaurentMazare/add_embedding

Adding embedding op (not generic gather, no select).
This commit is contained in:
Nicolas Patry
2023-06-23 13:49:13 +02:00
committed by GitHub
8 changed files with 160 additions and 0 deletions

View File

@ -345,6 +345,38 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
if !rhs.is_contiguous() {
return Err(Error::RequiresContiguous { op: "embedding" });
} else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 {
return Err(Error::ShapeMismatchBinaryOp {
lhs: ids.shape.clone(),
rhs: rhs.shape.clone(),
op: "embedding",
});
}
let seq_len = ids.shape().r1()?;
let (vocab_size, hidden_size) = rhs.shape().r2()?;
let storage = ids
.storage
.embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
op,
is_variable: false,
};
Ok(Self(Arc::new(tensor_)))
}
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(self.dims(), self.stride())
}
@ -741,6 +773,7 @@ impl Tensor {
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
| Op::Div(lhs, rhs)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
@ -832,6 +865,9 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Embedding(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported { op: "embedding" })
}
Op::Matmul(lhs, rhs) => {
// Skipping checks, the op went ok, we can skip
// the matmul size checks for now.