mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Improve how we check that the dims are in bounds.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@ -15,3 +15,6 @@ Cargo.lock
|
||||
|
||||
*tokenizer.json
|
||||
*.npz
|
||||
|
||||
perf.data
|
||||
flamegraph.svg
|
||||
|
@ -10,6 +10,13 @@ pub enum Error {
|
||||
got: DType,
|
||||
},
|
||||
|
||||
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
|
||||
DimOutOfRange {
|
||||
shape: Shape,
|
||||
dim: usize,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
|
||||
NarrowInvalidArgs {
|
||||
shape: Shape,
|
||||
|
@ -348,11 +348,24 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||
if dim >= self.dims().len() {
|
||||
Err(Error::DimOutOfRange {
|
||||
shape: self.shape().clone(),
|
||||
dim,
|
||||
op,
|
||||
})?
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + len`.
|
||||
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||
let dims = self.dims();
|
||||
if dim >= dims.len() || start + len > dims[dim] {
|
||||
self.check_dim(dim, "narrow")?;
|
||||
if start + len > dims[dim] {
|
||||
Err(Error::NarrowInvalidArgs {
|
||||
shape: self.shape().clone(),
|
||||
dim,
|
||||
@ -380,6 +393,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
||||
self.check_dim(dim, "softmax")?;
|
||||
// TODO: unify the two branches.
|
||||
if self.device().is_cuda() {
|
||||
// We do not have a cuda kernel for divide_by_sum_over_dim so split
|
||||
@ -402,6 +416,9 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
||||
for &dim in sum_dims {
|
||||
self.check_dim(dim, "sum")?;
|
||||
}
|
||||
let storage = self.storage.sum(self.layout(), sum_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
|
||||
@ -645,6 +662,8 @@ impl Tensor {
|
||||
} else {
|
||||
let start_dim = start_dim.unwrap_or(0);
|
||||
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
|
||||
self.check_dim(start_dim, "flatten")?;
|
||||
self.check_dim(end_dim, "flatten")?;
|
||||
if start_dim < end_dim {
|
||||
let dims = self.dims();
|
||||
let mut dst_dims = dims[..start_dim].to_vec();
|
||||
@ -689,6 +708,8 @@ impl Tensor {
|
||||
/// Returns a tensor that is a transposed version of the input, the given dimensions are
|
||||
/// swapped.
|
||||
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
|
||||
self.check_dim(dim1, "transpose")?;
|
||||
self.check_dim(dim2, "transpose")?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Transpose(self.clone(), dim1, dim2))
|
||||
} else {
|
||||
@ -876,6 +897,7 @@ impl Tensor {
|
||||
// The PyTorch semantics are to return the same tensor if the target dimension
|
||||
// does not have a size of 1.
|
||||
let dims = self.dims();
|
||||
self.check_dim(index, "squeeze")?;
|
||||
if dims[index] == 1 {
|
||||
let mut dims = dims.to_vec();
|
||||
dims.remove(index);
|
||||
@ -910,13 +932,8 @@ impl Tensor {
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let rank = arg0.rank();
|
||||
if dim >= rank {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
expected: (dim + 1),
|
||||
got: rank,
|
||||
shape: arg0.shape().clone(),
|
||||
});
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
if dim == 0 {
|
||||
Self::cat0(args)
|
||||
|
Reference in New Issue
Block a user