mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add a couple functions required for yolo. (#527)
This commit is contained in:
@ -487,6 +487,28 @@ impl Tensor {
|
||||
self.to_scalar::<S>()
|
||||
}
|
||||
|
||||
/// Repeat this tensor along the specified dimensions.
|
||||
pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
|
||||
// Similar to PyTorch, we extend the number of dimensions of self if needed.
|
||||
let repeats = shape.into();
|
||||
let repeats = repeats.dims();
|
||||
let mut inp = if self.rank() < repeats.len() {
|
||||
let mut shape = self.dims().to_vec();
|
||||
while shape.len() < repeats.len() {
|
||||
shape.push(1)
|
||||
}
|
||||
self.reshape(shape)?
|
||||
} else {
|
||||
self.clone()
|
||||
};
|
||||
for (idx, &repeat) in repeats.iter().enumerate() {
|
||||
if repeat > 1 {
|
||||
inp = Tensor::cat(&vec![&inp; repeat], idx)?
|
||||
}
|
||||
}
|
||||
Ok(inp)
|
||||
}
|
||||
|
||||
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
|
||||
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
|
||||
/// be performed.
|
||||
|
Reference in New Issue
Block a user