Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Merge pull request #6 from LaurentMazare/add_embedding
Adding embedding op (not generic gather, no select).
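In short: the diff below adds a Tensor::embedding forward op (a row lookup into a (vocab_size, hidden_size) matrix by a 1D tensor of ids), wires Op::Embedding into the backward graph walk, and leaves its backward pass as an explicit BackwardNotSupported error.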
@@ -345,6 +345,38 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
+    pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
+        if !rhs.is_contiguous() {
+            return Err(Error::RequiresContiguous { op: "embedding" });
+        } else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 {
+            return Err(Error::ShapeMismatchBinaryOp {
+                lhs: ids.shape.clone(),
+                rhs: rhs.shape.clone(),
+                op: "embedding",
+            });
+        }
+        let seq_len = ids.shape().r1()?;
+        let (vocab_size, hidden_size) = rhs.shape().r2()?;
+        let storage = ids
+            .storage
+            .embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
+        let shape: Shape = (seq_len, hidden_size).into();
+        let op = if ids.track_op() || rhs.track_op() {
+            Some(Op::Embedding(ids.clone(), rhs.clone()))
+        } else {
+            None
+        };
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape: shape.clone(),
+            stride: shape.stride_contiguous(),
+            op,
+            is_variable: false,
+        };
+        Ok(Self(Arc::new(tensor_)))
+    }
+
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
         crate::StridedIndex::new(self.dims(), self.stride())
     }
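For context: the gather that embedding_impl delegates to the storage backend amounts to copying, for each id, one row of the (vocab_size, hidden_size) weight matrix into the output. A minimal sketch on plain slices, assuming row-major contiguous weights (embedding_rows is a hypothetical name for illustration, not the crate's API):

    // Hypothetical stand-in for the row gather behind `embedding_impl`.
    // Assumes `weights` is a contiguous row-major (vocab_size, hidden_size)
    // matrix, which is what the `is_contiguous` check above enforces.
    // An out-of-range id would panic here via the slice bounds check.
    fn embedding_rows(ids: &[u32], weights: &[f32], hidden_size: usize) -> Vec<f32> {
        let mut out = Vec::with_capacity(ids.len() * hidden_size);
        for &id in ids {
            let start = id as usize * hidden_size;
            out.extend_from_slice(&weights[start..start + hidden_size]);
        }
        out
    }

This is also why the op rejects non-contiguous weights up front: row id is only the flat slice weights[id * hidden_size .. (id + 1) * hidden_size] when the layout is contiguous.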
@@ -741,6 +773,7 @@ impl Tensor {
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
+                | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
                     track_grad |= tg;
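The single added line above registers Op::Embedding with the graph walk that collects nodes for backpropagation; as with the other binary ops, a node needs gradient tracking if either input does, hence the OR-accumulation into track_grad. A toy, self-contained sketch of that pattern (types and names here are illustrative stand-ins, not the crate's):

    use std::collections::HashMap;

    enum Node {
        Leaf { id: usize, tracked: bool },
        Binary { id: usize, lhs: Box<Node>, rhs: Box<Node> },
    }

    // Depth-first walk: memoize each node's flag and OR together whether any
    // input tracks gradients, mirroring `track_grad |= tg` in the hunk above.
    // Children are pushed to `order` before their parents, so iterating the
    // collected nodes in reverse visits parents first during backprop.
    fn walk(node: &Node, seen: &mut HashMap<usize, bool>, order: &mut Vec<usize>) -> bool {
        let id = match node {
            Node::Leaf { id, .. } | Node::Binary { id, .. } => *id,
        };
        if let Some(&tracked) = seen.get(&id) {
            return tracked;
        }
        let tracked = match node {
            Node::Leaf { tracked, .. } => *tracked,
            Node::Binary { lhs, rhs, .. } => {
                let mut track_grad = walk(lhs, seen, order);
                track_grad |= walk(rhs, seen, order);
                if track_grad {
                    order.push(id);
                }
                track_grad
            }
        };
        seen.insert(id, tracked);
        tracked
    }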
@@ -832,6 +865,9 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
+                    Op::Embedding(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported { op: "embedding" })
+                    }
                     Op::Matmul(lhs, rhs) => {
                         // Skipping checks, the op went ok, we can skip
                         // the matmul size checks for now.
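Backward for the new op is deliberately left unimplemented here (BackwardNotSupported). For what it would involve: the gradient w.r.t. the weight matrix is a scatter-add of the output gradient rows into the rows selected by the ids (the ids themselves receive no gradient). A plain-slice sketch, assuming row-major layout; embedding_grad is a hypothetical name, not part of the diff:

    // Hedged sketch of the missing backward pass: accumulate each output
    // row's gradient into the weight row selected by the corresponding id.
    // Repeated ids correctly sum their contributions.
    fn embedding_grad(
        ids: &[u32],
        out_grad: &[f32], // (seq_len, hidden_size), row-major
        vocab_size: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut weight_grad = vec![0f32; vocab_size * hidden_size];
        for (row, &id) in ids.iter().enumerate() {
            let id = id as usize;
            let src = &out_grad[row * hidden_size..(row + 1) * hidden_size];
            let dst = &mut weight_grad[id * hidden_size..(id + 1) * hidden_size];
            for (d, s) in dst.iter_mut().zip(src) {
                *d += *s;
            }
        }
        weight_grad
    }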