mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the index-select op. (#209)
* Add the index-select op. * Cpu implementation of index-select. * Add the cpu implementation for index-select.
This commit is contained in:
@ -40,6 +40,7 @@ impl Tensor {
|
||||
..
|
||||
}
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::IndexSelect(lhs, rhs, _)
|
||||
| Op::Embedding(lhs, rhs)
|
||||
| Op::Matmul(lhs, rhs) => {
|
||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||
@ -143,9 +144,12 @@ impl Tensor {
|
||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::IndexSelect(_lhs, _rhs, _) => {
|
||||
Err(Error::BackwardNotSupported { op: "index-select" })?
|
||||
}
|
||||
Op::Embedding(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||
Err(Error::BackwardNotSupported { op: "embedding" })?
|
||||
}
|
||||
Op::Matmul(lhs, rhs) => {
|
||||
// Skipping checks, the op went ok, we can skip
|
||||
@ -195,10 +199,10 @@ impl Tensor {
|
||||
}
|
||||
Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
|
||||
Op::Reduce(_args, ReduceOp::Max, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "max" })
|
||||
Err(Error::BackwardNotSupported { op: "max" })?
|
||||
}
|
||||
Op::Reduce(_args, ReduceOp::Min, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "min" })
|
||||
Err(Error::BackwardNotSupported { op: "min" })?
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
@ -221,9 +225,7 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Abs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "abs" })
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
|
||||
Op::Unary(arg, UnaryOp::Exp) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&(&grad / arg)?)?
|
||||
@ -258,21 +260,15 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Softmax(_arg, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "softmax" })
|
||||
}
|
||||
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
|
||||
Op::Reshape(arg) => {
|
||||
let arg_grad = grad.reshape(arg.dims())?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Gelu) => {
|
||||
return Err(Error::BackwardNotSupported { op: "gelu" })
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Relu) => {
|
||||
return Err(Error::BackwardNotSupported { op: "relu" })
|
||||
}
|
||||
Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
|
||||
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
|
||||
Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
|
||||
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
|
||||
Op::Unary(arg, UnaryOp::Sqr) => {
|
||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user