mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Validate the kernel size in pooling ops. (#1473)
* Validate the kernel size in pooling ops. * Revert the changes to basics.
This commit is contained in:
@ -396,7 +396,7 @@ impl Tensor {
|
|||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
if D::is_zero(&step) {
|
if D::is_zero(&step) {
|
||||||
crate::bail!("step cannot be zero")
|
bail!("step cannot be zero")
|
||||||
}
|
}
|
||||||
let mut data = vec![];
|
let mut data = vec![];
|
||||||
let mut current = start;
|
let mut current = start;
|
||||||
@ -1041,6 +1041,9 @@ impl Tensor {
|
|||||||
let kernel_size = kernel_size.to_usize2();
|
let kernel_size = kernel_size.to_usize2();
|
||||||
let stride = stride.to_usize2();
|
let stride = stride.to_usize2();
|
||||||
let (n, c, h, w) = self.dims4()?;
|
let (n, c, h, w) = self.dims4()?;
|
||||||
|
if h < kernel_size.0 || w < kernel_size.1 {
|
||||||
|
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
|
||||||
|
}
|
||||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||||
@ -1076,6 +1079,9 @@ impl Tensor {
|
|||||||
let kernel_size = kernel_size.to_usize2();
|
let kernel_size = kernel_size.to_usize2();
|
||||||
let stride = stride.to_usize2();
|
let stride = stride.to_usize2();
|
||||||
let (n, c, h, w) = self.dims4()?;
|
let (n, c, h, w) = self.dims4()?;
|
||||||
|
if h < kernel_size.0 || w < kernel_size.1 {
|
||||||
|
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
|
||||||
|
}
|
||||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||||
@ -1798,7 +1804,7 @@ impl Tensor {
|
|||||||
let is_permutation =
|
let is_permutation =
|
||||||
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
|
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
|
||||||
if !is_permutation {
|
if !is_permutation {
|
||||||
crate::bail!(
|
bail!(
|
||||||
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
||||||
self.dims(),
|
self.dims(),
|
||||||
dims
|
dims
|
||||||
@ -2293,7 +2299,7 @@ impl Tensor {
|
|||||||
if left == 0 && right == 0 {
|
if left == 0 && right == 0 {
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
} else if self.elem_count() == 0 {
|
} else if self.elem_count() == 0 {
|
||||||
crate::bail!("cannot use pad_with_same on an empty tensor")
|
bail!("cannot use pad_with_same on an empty tensor")
|
||||||
} else if left == 0 {
|
} else if left == 0 {
|
||||||
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
||||||
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
|
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
|
||||||
@ -2457,13 +2463,13 @@ impl Tensor {
|
|||||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||||
let rank = self.rank() as i64;
|
let rank = self.rank() as i64;
|
||||||
if rank <= axis {
|
if rank <= axis {
|
||||||
crate::bail!("axis {axis} is too large, tensor rank {rank}")
|
bail!("axis {axis} is too large, tensor rank {rank}")
|
||||||
} else if 0 <= axis {
|
} else if 0 <= axis {
|
||||||
Ok(axis as usize)
|
Ok(axis as usize)
|
||||||
} else {
|
} else {
|
||||||
let naxis = rank + axis;
|
let naxis = rank + axis;
|
||||||
if naxis < 0 {
|
if naxis < 0 {
|
||||||
crate::bail!("axis {axis} is too small, tensor rank {rank}")
|
bail!("axis {axis} is too small, tensor rank {rank}")
|
||||||
}
|
}
|
||||||
Ok(naxis as usize)
|
Ok(naxis as usize)
|
||||||
}
|
}
|
||||||
@ -2525,14 +2531,14 @@ impl Tensor {
|
|||||||
let src_dims = src.dims();
|
let src_dims = src.dims();
|
||||||
let self_dims = self.dims();
|
let self_dims = self.dims();
|
||||||
if self_dims.len() != src_dims.len() {
|
if self_dims.len() != src_dims.len() {
|
||||||
crate::bail!(
|
bail!(
|
||||||
"slice-assign requires input with the same rank {} <> {}",
|
"slice-assign requires input with the same rank {} <> {}",
|
||||||
self_dims.len(),
|
self_dims.len(),
|
||||||
src_dims.len()
|
src_dims.len()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if self_dims.len() != ranges.len() {
|
if self_dims.len() != ranges.len() {
|
||||||
crate::bail!(
|
bail!(
|
||||||
"slice-assign requires input with the same rank as there are ranges {} <> {}",
|
"slice-assign requires input with the same rank as there are ranges {} <> {}",
|
||||||
self_dims.len(),
|
self_dims.len(),
|
||||||
ranges.len()
|
ranges.len()
|
||||||
@ -2552,18 +2558,16 @@ impl Tensor {
|
|||||||
std::ops::Bound::Excluded(v) => *v,
|
std::ops::Bound::Excluded(v) => *v,
|
||||||
};
|
};
|
||||||
if end_excluded <= start_included {
|
if end_excluded <= start_included {
|
||||||
crate::bail!(
|
bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
|
||||||
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
if self_dims[i] < end_excluded {
|
if self_dims[i] < end_excluded {
|
||||||
crate::bail!(
|
bail!(
|
||||||
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
|
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
|
||||||
self_dims[i]
|
self_dims[i]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if end_excluded - start_included != src_dims[i] {
|
if end_excluded - start_included != src_dims[i] {
|
||||||
crate::bail!(
|
bail!(
|
||||||
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
|
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user