fixes slice_scatter dim type (#988)

This commit is contained in:
Gonzalo
2023-09-29 03:54:45 -03:00
committed by GitHub
parent 53510ce427
commit 01b92cd959

View File

@ -1133,7 +1133,7 @@ impl Tensor {
} }
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension. /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: usize, start: usize) -> Result<Self> { pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
let dim = dim.to_index(self.shape(), "slice-scatter")?; let dim = dim.to_index(self.shape(), "slice-scatter")?;
if dim == 0 { if dim == 0 {
self.slice_scatter0(src, start) self.slice_scatter0(src, start)