Optimize the cat operation on contiguous tensors (#1855)

* Add a specialized kernel for copy2d.

* Move the cat operations.

* Avoid transpositions in cat.

* Bugfix.

* Bugfix for the cuda kernel.

* Add a benchmark.

* Add more testing.

* Test fix.

* Faster kernel.

* Add the missing kernel.

* Tweak the test.

* Add a metal kernel.

* Fix for the metal kernel.

* Get the tests to pass on metal.

* Also use this opportunity to fix the metal kernel for ELU.

* Add some bf16 kernels.

* Clippy fixes.
This commit is contained in:
Laurent Mazare
2024-03-17 10:49:13 +01:00
committed by GitHub
parent db8b24ae92
commit ce9fbc3682
19 changed files with 744 additions and 208 deletions

View File

@ -98,6 +98,19 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
#[allow(clippy::too_many_arguments)]
// Similar to cudaMemcpy2D, though values are in elements and not in bytes.
fn copy2d(
&self,
_: &mut Self,
_d1: usize,
_d2: usize,
_src_stride1: usize,
_dst_stride1: usize,
_src_offset: usize,
_dst_offset: usize,
) -> Result<()>;
}
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {

View File

@ -1023,6 +1023,26 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
}
}
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
src: &[T],
dst: &mut [T],
d1: usize,
d2: usize,
src_stride1: usize,
dst_stride1: usize,
src_offset: usize,
dst_offset: usize,
) {
for i1 in 0..d1 {
let dst_idx = i1 * dst_stride1 + dst_offset;
let src_idx = i1 * src_stride1 + src_offset;
let dst = &mut dst[dst_idx..dst_idx + d2];
let src = &src[src_idx..src_idx + d2];
dst.copy_from_slice(src)
}
}
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
match src_l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
@ -2452,6 +2472,48 @@ impl BackendStorage for CpuStorage {
}
}
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
(Self::U32(src), Self::U32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::I64(src), Self::I64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::BF16(src), Self::BF16(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F16(src), Self::F16(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F32(src), Self::F32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F64(src), Self::F64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(_, dst) => {
return Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: dst.dtype(),
op: "copy2d",
}
.bt());
}
}
Ok(())
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),

View File

@ -2145,6 +2145,67 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
let dev = &self.device;
let d1 = d1 as u32;
let d2 = d2 as u32;
let dst_s = dst_s as u32;
let src_s = src_s as u32;
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
(S::U8(s), S::U8(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_u8",
),
(S::U32(s), S::U32(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_u32",
),
(S::I64(s), S::I64(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_i64",
),
(S::BF16(s), S::BF16(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_bf16",
),
(S::F16(s), S::F16(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f16",
),
(S::F32(s), S::F32(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f32",
),
(S::F64(s), S::F64(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f64",
),
_ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
};
let func = dev.get_or_load_func(kname, kernels::FILL)?;
let cfg = LaunchConfig::for_num_elems(d1 * d2);
let params = (src, dst, d1, d2, src_s, dst_s);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let dims = src_shape.dims();

View File

@ -154,6 +154,19 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -166,6 +166,19 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -67,6 +67,7 @@ pub mod shape;
mod storage;
mod strided_index;
mod tensor;
mod tensor_cat;
pub mod test_utils;
pub mod utils;
mod variable;

View File

@ -422,6 +422,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
DType::BF16 => "powf_bf16",
dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"),
};
candle_metal_kernels::call_powf(
@ -439,6 +440,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "powf_f32_strided",
DType::F16 => "powf_f16_strided",
DType::BF16 => "powf_bf16_strided",
dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"),
};
candle_metal_kernels::call_powf_strided(
@ -471,6 +473,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
DType::BF16 => "elu_bf16",
dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"),
};
candle_metal_kernels::call_elu(
@ -488,6 +491,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "elu_f32_strided",
DType::F16 => "elu_f16_strided",
DType::BF16 => "elu_bf16_strided",
dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"),
};
candle_metal_kernels::call_elu_strided(
@ -1292,6 +1296,67 @@ impl BackendStorage for MetalStorage {
))
}
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
if self.dtype() != dst.dtype() {
crate::bail!(
"copy2d with inconsistent dtypes {:?} {:?}",
self.dtype(),
dst.dtype()
)
}
let command_buffer = self.device.command_buffer()?;
if src_s == d2 && dst_s == d2 {
command_buffer.set_label("copy2d_contiguous");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("copy2d_contiguous");
let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger;
let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
blit.end_encoding();
} else {
let el_count = d1 * d2;
if el_count == 0 {
return Ok(());
}
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::copy2d::FLOAT,
DType::F16 => candle_metal_kernels::copy2d::HALF,
DType::BF16 => candle_metal_kernels::copy2d::BFLOAT,
DType::I64 => candle_metal_kernels::copy2d::I64,
DType::U32 => candle_metal_kernels::copy2d::U32,
DType::U8 => candle_metal_kernels::copy2d::U8,
dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"),
};
candle_metal_kernels::call_copy2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
&self.buffer,
&dst.buffer,
d1,
d2,
src_s,
dst_s,
src_o * self.dtype.size_in_bytes(),
dst_o * self.dtype.size_in_bytes(),
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy2d");
}
Ok(())
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer()?;
if src_l.is_contiguous() && self.dtype == dst.dtype() {

View File

@ -701,4 +701,32 @@ impl Storage {
.bt()),
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
(Self::Cuda(src), Self::Cuda(dst)) => {
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
}
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "copy2d",
}
.bt()),
}
}
}

View File

@ -666,7 +666,7 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
if dim >= self.dims().len() {
Err(Error::DimOutOfRange {
shape: self.shape().clone(),
@ -2149,152 +2149,6 @@ impl Tensor {
Self::cat(&args, dim)
}
/// Concatenates two or more tensors along a particular dimension.
///
/// All tensors must of the same rank, and the output will have
/// the same rank
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = Tensor::cat(&[&a, &b], 0)?;
/// assert_eq!(c.shape().dims(), &[4, 3]);
///
/// let c = Tensor::cat(&[&a, &b], 1)?;
/// assert_eq!(c.shape().dims(), &[2, 6]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let dim = dim.to_index(arg0.shape(), "cat")?;
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {
// TODO: Avoid these transpositions and have an implementation that works
// for dim != 0...
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[0] = 0;
let mut offsets = vec![0usize];
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == 0 {
cat_dims[0] += v2;
}
if dim_idx != 0 && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(from_storage(storage, shape, op, false))
}
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after.
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {

View File

@ -0,0 +1,240 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};
impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
///
/// All tensors must of the same rank, and the output will have
/// the same rank
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = Tensor::cat(&[&a, &b], 0)?;
/// assert_eq!(c.shape().dims(), &[4, 3]);
///
/// let c = Tensor::cat(&[&a, &b], 1)?;
/// assert_eq!(c.shape().dims(), &[2, 6]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let dim = dim.to_index(arg0.shape(), "cat")?;
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
}
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[0] = 0;
let mut offsets = vec![0usize];
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == 0 {
cat_dims[0] += v2;
}
if dim_idx != 0 && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[dim] = 0;
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == dim {
cat_dims[dim] += v2;
}
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
let cat_target_dim_len = cat_dims[dim];
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
let mut storage = device.zeros(&shape, dtype)?;
let mut dst_o = 0;
for arg in args.iter() {
let arg = arg.as_ref();
let arg_dims = arg.shape().dims();
let d1: usize = arg_dims.iter().take(dim).product();
let d2 = block_size * arg_dims[dim];
let dst_s = block_size * cat_target_dim_len;
let src_o = arg.layout().start_offset();
arg.storage().copy2d(
&mut storage,
d1,
d2,
/* src_s */ d2,
dst_s,
src_o,
dst_o,
)?;
dst_o += d2;
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
}