Use broadcasted scalars for const tensors.

This commit is contained in:
laurent
2023-06-29 11:56:40 +01:00
parent 3872dc4751
commit 2741b39ad3
5 changed files with 12 additions and 14 deletions

View File

@ -99,7 +99,7 @@ impl Tensor {
pub fn backward(&self) -> Result<GradStore> { pub fn backward(&self) -> Result<GradStore> {
let sorted_nodes = self.sorted_nodes(); let sorted_nodes = self.sorted_nodes();
let mut grads = GradStore::new(); let mut grads = GradStore::new();
grads.insert(self, self.ones_like()?); grads.insert(self, self.ones_like()?.contiguous()?);
for node in sorted_nodes.iter() { for node in sorted_nodes.iter() {
if node.is_variable() { if node.is_variable() {
continue; continue;

View File

@ -3,6 +3,8 @@ use crate::{Error, Result};
#[derive(Clone, PartialEq, Eq)] #[derive(Clone, PartialEq, Eq)]
pub struct Shape(Vec<usize>); pub struct Shape(Vec<usize>);
pub const SCALAR: Shape = Shape(vec![]);
impl std::fmt::Debug for Shape { impl std::fmt::Debug for Shape {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", &self.dims()) write!(f, "{:?}", &self.dims())

View File

@ -115,16 +115,14 @@ fn from_storage<S: Into<Shape>>(
} }
impl Tensor { impl Tensor {
// TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
fn ones_impl<S: Into<Shape>>( fn ones_impl<S: Into<Shape>>(
shape: S, shape: S,
dtype: DType, dtype: DType,
device: &Device, device: &Device,
is_variable: bool, is_variable: bool,
) -> Result<Self> { ) -> Result<Self> {
let shape = shape.into(); let storage = device.ones(&crate::shape::SCALAR, dtype)?;
let storage = device.ones(&shape, dtype)?; from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
Ok(from_storage(storage, shape, None, is_variable))
} }
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> { pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@ -132,6 +130,8 @@ impl Tensor {
} }
pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> { pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
// Maybe we should allocate some actual storage for vars rather than just using a
// broadcasted scalar?
Self::ones_impl(shape, dtype, device, true) Self::ones_impl(shape, dtype, device, true)
} }
@ -139,16 +139,14 @@ impl Tensor {
Tensor::ones(self.shape(), self.dtype(), &self.device()) Tensor::ones(self.shape(), self.dtype(), &self.device())
} }
// TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
fn zeros_impl<S: Into<Shape>>( fn zeros_impl<S: Into<Shape>>(
shape: S, shape: S,
dtype: DType, dtype: DType,
device: &Device, device: &Device,
is_variable: bool, is_variable: bool,
) -> Result<Self> { ) -> Result<Self> {
let shape = shape.into(); let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
let storage = device.zeros(&shape, dtype)?; from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
Ok(from_storage(storage, shape, None, is_variable))
} }
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> { pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@ -599,8 +597,7 @@ impl Tensor {
&self.layout &self.layout
} }
// TODO: Rename to `stride` once the PR that introduced the layout has been merged. pub fn stride(&self) -> &[usize] {
pub fn stride_tmp(&self) -> &[usize] {
self.layout.stride() self.layout.stride()
} }

View File

@ -20,7 +20,6 @@ fn matmul_grad(device: &Device) -> Result<()> {
let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?; let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| i as f32).collect(); let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?; let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
let c = x.matmul(&y)?; let c = x.matmul(&y)?;
let grads = c.backward()?; let grads = c.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?; let grad_x = grads.get(&x).context("no grad for x")?;

View File

@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
let a_tt = a.t()?.contiguous()?.t()?; let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous()); assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims()); assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]); assert_eq!(a_tt.stride(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?; let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous()); assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims()); assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]); assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected); assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected); assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);