Add a couple functions required for yolo. (#527)

This commit is contained in:
Laurent Mazare
2023-08-20 17:02:05 +01:00
committed by GitHub
parent 372f8912c5
commit e3d2786ffb
5 changed files with 69 additions and 1 deletions

View File

@ -487,6 +487,28 @@ impl Tensor {
self.to_scalar::<S>()
}
/// Repeat this tensor along the specified dimensions.
pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
// Similar to PyTorch, we extend the number of dimensions of self if needed.
let repeats = shape.into();
let repeats = repeats.dims();
let mut inp = if self.rank() < repeats.len() {
let mut shape = self.dims().to_vec();
while shape.len() < repeats.len() {
shape.push(1)
}
self.reshape(shape)?
} else {
self.clone()
};
for (idx, &repeat) in repeats.iter().enumerate() {
if repeat > 1 {
inp = Tensor::cat(&vec![&inp; repeat], idx)?
}
}
Ok(inp)
}
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
/// be performed.