Use broadcasted scalars for const tensors.

This commit is contained in:
laurent
2023-06-29 11:56:40 +01:00
parent 3872dc4751
commit 2741b39ad3
5 changed files with 12 additions and 14 deletions

View File

@@ -115,16 +115,14 @@ fn from_storage<S: Into<Shape>>(
}
impl Tensor {
-// TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
fn ones_impl<S: Into<Shape>>(
shape: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
-let shape = shape.into();
-let storage = device.ones(&shape, dtype)?;
-Ok(from_storage(storage, shape, None, is_variable))
+let storage = device.ones(&crate::shape::SCALAR, dtype)?;
+from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
}
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -132,6 +130,8 @@ impl Tensor {
}
pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+// Maybe we should allocate some actual storage for vars rather than just using a
+// broadcasted scalar?
Self::ones_impl(shape, dtype, device, true)
}
@@ -139,16 +139,14 @@ impl Tensor {
Tensor::ones(self.shape(), self.dtype(), &self.device())
}
-// TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
fn zeros_impl<S: Into<Shape>>(
shape: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
-let shape = shape.into();
-let storage = device.zeros(&shape, dtype)?;
-Ok(from_storage(storage, shape, None, is_variable))
+let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
+from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
}
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -599,8 +597,7 @@ impl Tensor {
&self.layout
}
// TODO: Rename to `stride` once the PR that introduced the layout has been merged.
-pub fn stride_tmp(&self) -> &[usize] {
+pub fn stride(&self) -> &[usize] {
self.layout.stride()
}