mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Skeleton implementation for the narrow method and op.
This commit is contained in:
@ -27,9 +27,9 @@ pub(crate) enum Op {
|
||||
Sin(Tensor),
|
||||
Cos(Tensor),
|
||||
Abs(Tensor),
|
||||
Narrow(Tensor, usize, usize, usize),
|
||||
Neg(Tensor),
|
||||
Reshape(Tensor),
|
||||
#[allow(dead_code)]
|
||||
Softmax(Tensor, usize),
|
||||
Sqr(Tensor),
|
||||
Sqrt(Tensor),
|
||||
|
@ -298,6 +298,36 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + length`.
|
||||
// TODO: Once we've refactor the shape and strides, make this return a view of the same data
|
||||
// rather than copying.
|
||||
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
|
||||
let dims = self.shape().dims();
|
||||
if dim >= dims.len() {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
expected: dim + 1,
|
||||
got: dims.len(),
|
||||
shape: self.shape().clone(),
|
||||
});
|
||||
}
|
||||
if start + length > dims[dim] {
|
||||
todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
|
||||
}
|
||||
let mut dims = dims.to_vec();
|
||||
dims[dim] = length;
|
||||
let shape = Shape::from(dims);
|
||||
let storage = self.device().zeros(&shape, self.dtype())?;
|
||||
// TODO: Actually copy the data, compared to copy_strided_src this requires a src start
|
||||
// offset as well as a way to specify the number of elements to be copied.
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Narrow(self.clone(), dim, start, length))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
||||
let shape = self.shape();
|
||||
let mut storage = self
|
||||
@ -817,6 +847,7 @@ impl Tensor {
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Softmax(node, _)
|
||||
| Op::Sqr(node)
|
||||
| Op::Sqrt(node)
|
||||
@ -933,7 +964,10 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
|
||||
Op::Cat(_args, _dim) => {
|
||||
// TODO: Use narrow here.
|
||||
return Err(Error::BackwardNotSupported { op: "cat" });
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
|
||||
@ -964,6 +998,9 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&grad)?
|
||||
}
|
||||
Op::Narrow(_arg, _, _, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "narrow" })
|
||||
}
|
||||
Op::Softmax(_arg, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "softmax" })
|
||||
}
|
||||
|
Reference in New Issue
Block a user