mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Skeleton implementation for the narrow method and op.
@@ -27,9 +27,10 @@ pub(crate) enum Op {
     Sin(Tensor),
     Cos(Tensor),
     Abs(Tensor),
+    Narrow(Tensor, usize, usize, usize),
     Neg(Tensor),
     Reshape(Tensor),
     #[allow(dead_code)]
     Softmax(Tensor, usize),
     Sqr(Tensor),
     Sqrt(Tensor),
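Note: the new Op::Narrow variant records the source tensor together with (dim, start, length), i.e. everything a future backward pass needs to route gradients. For orientation, a hypothetical call site for the method added below (shape semantics only at this commit, since the skeleton fills the result with zeros; Tensor::new and Device::Cpu are assumed constructor names, not part of this diff):

    let t = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &Device::Cpu)?; // shape (2, 3)
    let n = t.narrow(1, 1, 2)?; // keep indexes 1..3 along dim 1 -> shape (2, 2)
    assert_eq!(n.shape().dims(), &[2, 2]);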
@@ -298,6 +298,36 @@ impl Tensor {
         Ok(from_storage(storage, shape.clone(), op, false))
     }
 
+    /// Returns a new tensor that is a narrowed version of the input; the dimension `dim`
+    /// ranges from `start` to `start + length`.
+    // TODO: Once we've refactored the shape and strides, make this return a view of the same data
+    // rather than copying.
+    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+        let dims = self.shape().dims();
+        if dim >= dims.len() {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: dim + 1,
+                got: dims.len(),
+                shape: self.shape().clone(),
+            });
+        }
+        if start + length > dims[dim] {
+            todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
+        }
+        let mut dims = dims.to_vec();
+        dims[dim] = length;
+        let shape = Shape::from(dims);
+        let storage = self.device().zeros(&shape, self.dtype())?;
+        // TODO: Actually copy the data, compared to copy_strided_src this requires a src start
+        // offset as well as a way to specify the number of elements to be copied.
+        let op = if self.track_op() {
+            Some(Op::Narrow(self.clone(), dim, start, length))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape, op, false))
+    }
+
     pub fn softmax(&self, dim: usize) -> Result<Self> {
         let shape = self.shape();
         let mut storage = self
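The method above allocates zeroed storage and leaves the actual copy as a TODO. For a contiguous row-major buffer, the copy it will eventually need amounts to taking `length` indexes starting at `start` along `dim` for every leading index. A minimal self-contained sketch, with illustrative names that are not part of the candle API:

    /// Narrow a contiguous row-major buffer `src` of shape `dims` along `dim`,
    /// keeping `length` indexes starting at `start`.
    fn narrow_contiguous(src: &[f32], dims: &[usize], dim: usize, start: usize, length: usize) -> Vec<f32> {
        let inner: usize = dims[dim + 1..].iter().product(); // elements per step along `dim`
        let outer: usize = dims[..dim].iter().product(); // independent leading blocks
        let src_stride = dims[dim] * inner;
        let mut dst = Vec::with_capacity(outer * length * inner);
        for o in 0..outer {
            let offset = o * src_stride + start * inner;
            dst.extend_from_slice(&src[offset..offset + length * inner]);
        }
        dst
    }

    fn main() {
        // Narrow a 2x3 matrix along dim 1, keeping columns 1..3.
        let src = [0., 1., 2., 3., 4., 5.];
        assert_eq!(narrow_contiguous(&src, &[2, 3], 1, 1, 2), vec![1., 2., 4., 5.]);
    }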
@@ -817,6 +847,7 @@ impl Tensor {
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
+                | Op::Narrow(node, _, _, _)
                 | Op::Softmax(node, _)
                 | Op::Sqr(node)
                 | Op::Sqrt(node)
@@ -933,7 +964,10 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
+                Op::Cat(_args, _dim) => {
+                    // TODO: Use narrow here.
+                    return Err(Error::BackwardNotSupported { op: "cat" });
+                }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
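The TODO in the new Cat arm points at the intended fix: once narrow copies data, each argument's gradient is just the matching slice of the incoming gradient along the cat dimension. A guess at what the arm might become (a sketch, not the eventual candle code; it only reuses calls visible elsewhere in this diff):

    Op::Cat(args, dim) => {
        let mut start = 0;
        for arg in args.iter() {
            let len = arg.shape().dims()[*dim];
            // Slice this argument's share back out of the concatenated gradient.
            let arg_grad = grad.narrow(*dim, start, len)?;
            let sum_grad = grads.or_insert(arg)?;
            *sum_grad = sum_grad.add(&arg_grad)?;
            start += len;
        }
    }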
@@ -964,6 +998,9 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.sub(&grad)?
                 }
+                Op::Narrow(_arg, _, _, _) => {
+                    return Err(Error::BackwardNotSupported { op: "narrow" })
+                }
                 Op::Softmax(_arg, _) => {
                     return Err(Error::BackwardNotSupported { op: "softmax" })
                 }
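Backward for narrow itself is not supported yet either; conceptually it is the adjoint of the forward copy, scattering the incoming gradient into a zero buffer of the source shape. A contiguous row-major sketch with illustrative names (not the candle API):

    /// Scatter `grad` (the gradient of the narrowed output) back into a zeroed
    /// buffer shaped like the source, inverting the narrow copy.
    fn narrow_backward(grad: &[f32], src_dims: &[usize], dim: usize, start: usize, length: usize) -> Vec<f32> {
        let inner: usize = src_dims[dim + 1..].iter().product();
        let outer: usize = src_dims[..dim].iter().product();
        let src_stride = src_dims[dim] * inner;
        let mut out = vec![0.0; outer * src_stride];
        for o in 0..outer {
            let dst = o * src_stride + start * inner;
            let src = o * length * inner;
            out[dst..dst + length * inner].copy_from_slice(&grad[src..src + length * inner]);
        }
        out
    }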