mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Address comments.
This commit is contained in:
@ -128,8 +128,8 @@ impl Storage {
|
|||||||
hidden_size: usize,
|
hidden_size: usize,
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
self.same_device(rhs, "matmul")?;
|
self.same_device(rhs, "embedding")?;
|
||||||
self.same_dtype(rhs, "matmul")?;
|
self.same_dtype(rhs, "embedding")?;
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||||
let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
|
let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
|
||||||
|
@ -360,12 +360,17 @@ impl Tensor {
|
|||||||
.storage
|
.storage
|
||||||
.embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
|
.embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
|
||||||
let shape: Shape = (seq_len, hidden_size).into();
|
let shape: Shape = (seq_len, hidden_size).into();
|
||||||
|
let op = if ids.track_op() || rhs.track_op() {
|
||||||
|
Some(Op::Embedding(ids.clone(), rhs.clone()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage,
|
storage,
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
stride: shape.stride_contiguous(),
|
||||||
op: Some(Op::Embedding(ids.clone(), rhs.clone())),
|
op,
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
@ -859,7 +864,7 @@ impl Tensor {
|
|||||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||||
}
|
}
|
||||||
Op::Embedding(_lhs, _rhs) => {
|
Op::Embedding(_lhs, _rhs) => {
|
||||||
todo!("Backward for embedding not implemented");
|
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||||
}
|
}
|
||||||
Op::Matmul(lhs, rhs) => {
|
Op::Matmul(lhs, rhs) => {
|
||||||
// Skipping checks, the op went ok, we can skip
|
// Skipping checks, the op went ok, we can skip
|
||||||
|
Reference in New Issue
Block a user