Allow from_vec/from_slice to use a ShapeWithOneHole as shape. (#2905)

This commit is contained in:
Laurent Mazare
2025-04-17 08:59:18 +02:00
committed by GitHub
parent 7f0f83a7c1
commit 9954981327

View File

@ -3,7 +3,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::shape::{Dim, Dims, ShapeWithOneHole};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -452,17 +452,13 @@ impl Tensor {
Self::from_vec_impl(data, len, device, false)
}
pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
let buffer_size = data.len();
if buffer_size != shape.elem_count() {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(data.len())?;
let storage = device.storage_owned(data)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, is_variable))
@ -481,7 +477,7 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
@ -502,17 +498,12 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
array: &[D],
shape: S,
device: &Device,
) -> Result<Self> {
let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(array.len())?;
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
@ -2197,7 +2188,7 @@ impl Tensor {
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {