Tensor mutability (#154)

* Working towards tensor mutability.

* Use a ref-cell to provide tensor mutability.
Laurent Mazare
2023-07-13 11:04:40 +01:00
committed by GitHub
parent a3663ce2f2
commit 50b0946a2d
14 changed files with 124 additions and 88 deletions
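
A quick orientation before the diff: the change swaps the tensor's `Arc<Storage>` for `Arc<RefCell<Storage>>`, trading compile-time borrow checking for dynamically checked borrows so the storage can be mutated in place (e.g. to update variables). A minimal, self-contained sketch of the pattern (illustrative only, not the crate's API):

```rust
use std::cell::RefCell;
use std::sync::Arc;

fn main() {
    // Stand-in for the new `Arc<RefCell<Storage>>` field.
    let storage = Arc::new(RefCell::new(vec![1.0f32, 2.0, 3.0]));
    let alias = Arc::clone(&storage);

    {
        // Any number of shared borrows may coexist.
        let a = storage.try_borrow().unwrap();
        let b = alias.try_borrow().unwrap();
        assert_eq!(a.len(), b.len());
    }

    // An exclusive borrow allows in-place mutation, which is what makes
    // updating a variable possible without rebuilding the tensor.
    alias.borrow_mut()[0] = 42.0;
    assert_eq!(storage.borrow()[0], 42.0);
}
```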

@@ -1,6 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::shape::Dim;
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::cell::RefCell;
use std::sync::Arc;
/// Unique identifier for tensors.
@@ -18,10 +19,23 @@ impl TensorId {
pub struct Tensor_ {
id: TensorId,
storage: Arc<Storage>,
// Storage uses a RefCell here so inner mutability is available and borrow rules are checked
// dynamically. The alternatives would be:
// - Using a mutex, which would have the highest cost when retrieving the storage but would
// prevent errors on concurrent access. A mutex would also be subject to deadlocks, for
// example with the current code if the same tensor were used twice by a single binary op.
// - Using an unsafe cell, which would have the lowest cost but undefined behavior on
// concurrent accesses.
// Ideally, we would use Arc<Storage> for tensors whose data we don't plan to modify, and
// Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables, but
// that's tricky to encode in the current setup.
storage: Arc<RefCell<Storage>>,
layout: Layout,
op: Option<Op>,
is_variable: bool,
dtype: DType,
device: Device,
}
impl AsRef<Tensor> for Tensor {
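
The deadlock the comment warns about is concrete: `std::sync::Mutex` is not reentrant, so a binary op handed the same tensor on both sides would take the same lock twice on one thread. `RefCell` sidesteps this because shared borrows can coexist. A small illustration (the mutex variant is commented out since running it would hang or panic):

```rust
use std::cell::RefCell;

fn main() {
    // RefCell: borrowing the same storage twice for reading is fine,
    // so something like `t.add(&t)` can proceed.
    let cell = RefCell::new(vec![1u32, 2, 3]);
    let lhs = cell.try_borrow().unwrap();
    let rhs = cell.try_borrow().unwrap();
    assert_eq!(lhs[0] + rhs[0], 2);

    // Mutex: locking again on the same thread deadlocks or panics,
    // which is exactly the `t op t` hazard described above.
    // let m = std::sync::Mutex::new(vec![1u32, 2, 3]);
    // let a = m.lock().unwrap();
    // let b = m.lock().unwrap(); // never returns
}
```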
@@ -62,7 +76,7 @@ macro_rules! unary_op {
pub fn $fn_name(&self) -> Result<Self> {
let shape = self.shape();
let storage = self
.storage
.storage()?
.unary_impl::<crate::op::$op_name>(self.layout())?;
let op = if self.track_op() {
Some(Op::$op_name(self.clone()))
@@ -78,8 +92,8 @@ macro_rules! binary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
let storage = self.storage.binary_impl::<crate::op::$op_name>(
&rhs.storage,
let storage = self.storage()?.binary_impl::<crate::op::$op_name>(
&*rhs.storage()?,
self.layout(),
rhs.layout(),
)?;
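
Note the shape of the new calls: `storage()` yields a `std::cell::Ref<'_, Storage>` guard rather than a bare reference. Method calls like `.binary_impl(...)` resolve through the guard via `Deref`, while `&*rhs.storage()?` explicitly reborrows the guard into a plain `&Storage`. A standalone illustration of both:

```rust
use std::cell::{Ref, RefCell};

fn takes_ref(v: &Vec<u8>) -> usize {
    v.len()
}

fn main() {
    let cell = RefCell::new(vec![1u8, 2, 3]);
    let guard: Ref<'_, Vec<u8>> = cell.borrow();

    // Method calls go through `Deref` automatically.
    assert_eq!(guard.len(), 3);

    // `&*guard` turns the guard into a plain reference, mirroring
    // the `&*rhs.storage()?` pattern in the diff.
    assert_eq!(takes_ref(&*guard), 3);
}
```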
@@ -119,12 +133,16 @@ fn from_storage<S: Into<Shape>>(
op: Option<Op>,
is_variable: bool,
) -> Tensor {
let dtype = storage.dtype();
let device = storage.device();
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(storage),
storage: Arc::new(RefCell::new(storage)),
layout: Layout::contiguous(shape),
op,
is_variable,
dtype,
device,
};
Tensor(Arc::new(tensor_))
}
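
With the storage behind a `RefCell`, reading its dtype or device later would mean borrowing the cell (a fallible operation), so `from_storage` snapshots both while it still owns the value and stores them as plain fields. A sketch of that ordering, with a hypothetical `Buf` type standing in for `Storage`:

```rust
use std::cell::RefCell;
use std::sync::Arc;

// Hypothetical stand-in for Storage that knows its own metadata.
struct Buf {
    dtype: &'static str,
}

struct Wrapped {
    storage: Arc<RefCell<Buf>>,
    dtype: &'static str, // cached: readable without borrowing the cell
}

fn wrap(buf: Buf) -> Wrapped {
    let dtype = buf.dtype; // read the metadata before moving `buf` in
    Wrapped {
        storage: Arc::new(RefCell::new(buf)),
        dtype,
    }
}

fn main() {
    let w = wrap(Buf { dtype: "f32" });
    assert_eq!(w.dtype, "f32");
    assert_eq!(w.storage.borrow().dtype, "f32");
}
```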
@@ -169,7 +187,7 @@ impl Tensor {
/// # Ok::<(), candle::Error>(())
/// ```
pub fn ones_like(&self) -> Result<Self> {
Tensor::ones(self.shape(), self.dtype(), &self.device())
Tensor::ones(self.shape(), self.dtype(), self.device())
}
/// Creates a new tensor filled with zeros.
@@ -219,7 +237,7 @@ impl Tensor {
/// # Ok::<(), candle::Error>(())
/// ```
pub fn zeros_like(&self) -> Result<Self> {
Tensor::zeros(self.shape(), self.dtype(), &self.device())
Tensor::zeros(self.shape(), self.dtype(), self.device())
}
fn rand_impl<S: Into<Shape>>(
@@ -502,7 +520,7 @@ impl Tensor {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok::<_, Error>(data[self.layout().start_offset()])
};
match self.storage.as_ref() {
match &*self.storage()? {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
@@ -520,7 +538,7 @@ impl Tensor {
/// # Ok::<(), candle::Error>(())
/// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
let storage = self.storage.affine(self.layout(), mul, add)?;
let storage = self.storage()?.affine(self.layout(), mul, add)?;
let op = if self.track_op() {
Some(Op::Affine {
arg: self.clone(),
@@ -535,7 +553,7 @@ impl Tensor {
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
pub fn elu(&self, alpha: f64) -> Result<Self> {
let storage = self.storage.elu(self.layout(), alpha)?;
let storage = self.storage()?.elu(self.layout(), alpha)?;
let op = if self.track_op() {
Some(Op::Elu(self.clone(), alpha))
} else {
@@ -585,6 +603,8 @@ impl Tensor {
layout,
op,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -616,7 +636,9 @@ impl Tensor {
exp.broadcast_div(&sum_exp)
} else {
let shape = self.shape();
let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?;
let mut storage = self
.storage()?
.unary_impl::<crate::op::Exp>(self.layout())?;
// The resulting storage is contiguous.
storage.divide_by_sum_over_dim(shape, dim)?;
let op = if self.track_op() {
@@ -649,7 +671,7 @@ impl Tensor {
for &dim in sum_dims {
self.check_dim(dim, "sum")?;
}
let storage = self.storage.sum(self.layout(), sum_dims)?;
let storage = self.storage()?.sum(self.layout(), sum_dims)?;
let op = if self.track_op() {
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
} else {
@@ -695,8 +717,8 @@ impl Tensor {
stride,
};
let storage =
self.storage
.conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
self.storage()?
.conv1d(self.layout(), &*kernel.storage()?, kernel.layout(), &params)?;
let op = if self.track_op() || kernel.track_op() {
Some(Op::Conv1D {
arg: self.clone(),
@@ -749,8 +771,8 @@ impl Tensor {
})?
}
let storage = self.storage.matmul(
&rhs.storage,
let storage = self.storage()?.matmul(
&*rhs.storage()?,
(batching, m, n, k),
self.layout(),
rhs.layout(),
@@ -769,11 +791,11 @@ impl Tensor {
pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
let storage = self.storage.where_cond(
let storage = self.storage()?.where_cond(
self.layout(),
&on_true.storage,
&*on_true.storage()?,
on_true.layout(),
&on_false.storage,
&*on_false.storage()?,
on_false.layout(),
)?;
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
@@ -821,8 +843,8 @@ impl Tensor {
let seq_len = ids_shape.r1()?;
let (_, hidden_size) = rhs.shape().r2()?;
let storage = ids
.storage
.embedding(ids.layout(), &rhs.storage, rhs.layout())?;
.storage()?
.embedding(ids.layout(), &*rhs.storage()?, rhs.layout())?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))
@@ -836,23 +858,6 @@ impl Tensor {
self.layout.strided_index()
}
/// Returns data from the underlying storage; this does not take the strides
/// into account, so the size of the resulting buffer might be larger than the
/// tensor's number of elements.
pub fn storage_data<S: crate::WithDType>(&self) -> Result<std::borrow::Cow<[S]>> {
match self.storage.as_ref() {
Storage::Cpu(cpu_storage) => {
let slice = S::cpu_storage_as_slice(cpu_storage)?;
Ok(std::borrow::Cow::Borrowed(slice))
}
Storage::Cuda(slice) => {
let cpu_storage = slice.to_cpu_storage()?;
let storage_data = S::cpu_storage_data(cpu_storage)?;
Ok(std::borrow::Cow::Owned(storage_data))
}
}
}
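
`storage_data` had to go because its `Cow::Borrowed` variant pointed into the storage: the returned borrow would now have to outlive the `Ref` guard, which is dropped when the accessor returns. Copying the data out while the guard is alive is the shape that still works; a minimal sketch:

```rust
use std::cell::RefCell;

// Returning `&[f32]` (or `Cow::Borrowed`) from inside the borrow would
// tie the result to the `Ref` guard, which dies at the end of the call.
// An owned copy made while the guard is alive has no such problem.
fn data_owned(storage: &RefCell<Vec<f32>>) -> Vec<f32> {
    storage.borrow().clone()
}

fn main() {
    let s = RefCell::new(vec![1.0f32, 2.0]);
    assert_eq!(data_owned(&s), vec![1.0, 2.0]);
}
```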
/// Returns the data contained in a 1D tensor as a vector of scalar values.
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
if self.rank() != 1 {
@@ -862,7 +867,7 @@ impl Tensor {
shape: self.shape().clone(),
});
}
match self.storage.as_ref() {
match &*self.storage()? {
Storage::Cpu(cpu_storage) => {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(self.strided_index().map(|i| data[i]).collect())
@@ -890,7 +895,7 @@ impl Tensor {
assert!(src_index.next().is_none());
Ok(rows)
};
match self.storage.as_ref() {
match &*self.storage()? {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
@@ -914,7 +919,7 @@ impl Tensor {
assert!(src_index.next().is_none());
Ok(top_rows)
};
match self.storage.as_ref() {
match &*self.storage()? {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
@@ -922,12 +927,12 @@ impl Tensor {
/// The dtype for the elements stored in the input tensor.
pub fn dtype(&self) -> DType {
self.storage.dtype()
self.dtype
}
/// The device on which the input tensor is located.
pub fn device(&self) -> Device {
self.storage.device()
pub fn device(&self) -> &Device {
&self.device
}
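
Because dtype and device are now cached fields, these accessors are infallible and never touch the `RefCell`, so metadata stays readable even while the storage is exclusively borrowed for an in-place update. A small illustration of that property:

```rust
use std::cell::RefCell;

struct T {
    storage: RefCell<Vec<f32>>,
    dtype: &'static str, // cached, like Tensor_::dtype
}

fn main() {
    let t = T { storage: RefCell::new(vec![0.0]), dtype: "f32" };
    let mut guard = t.storage.borrow_mut(); // exclusive borrow in flight
    guard[0] = 1.0;
    assert_eq!(t.dtype, "f32"); // fine: no borrow of the cell required
}
```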
/// The tensor shape, i.e. dimension sizes on each axis.
@@ -1114,6 +1119,8 @@ impl Tensor {
layout: self.layout.transpose(dim1, dim2)?,
op,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1133,10 +1140,12 @@ impl Tensor {
pub fn copy(&self) -> Result<Tensor> {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(self.storage.try_clone(self.layout())?),
storage: Arc::new(RefCell::new(self.storage()?.try_clone(self.layout())?)),
layout: self.layout.clone(),
op: None, // TODO
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1150,6 +1159,8 @@ impl Tensor {
layout: self.layout.clone(),
op: None,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1159,7 +1170,7 @@ impl Tensor {
if self.device().same_device(device) {
Ok(self.clone())
} else {
let storage = match (self.storage.as_ref(), device) {
let storage = match (&*self.storage()?, device) {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
@@ -1179,10 +1190,12 @@ impl Tensor {
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(storage),
storage: Arc::new(RefCell::new(storage)),
layout: self.layout.clone(),
op,
is_variable: false,
dtype: self.dtype,
device: device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1216,6 +1229,8 @@ impl Tensor {
layout: self.layout.broadcast_as(shape)?,
op,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1240,7 +1255,7 @@ impl Tensor {
Ok(self.clone())
} else {
let shape = self.shape();
let storage = self.storage.to_dtype(self.layout(), dtype)?;
let storage = self.storage()?.to_dtype(self.layout(), dtype)?;
let op = if self.track_op() {
Some(Op::ToDType(self.clone()))
} else {
@@ -1258,7 +1273,7 @@ impl Tensor {
} else {
let shape = self.shape();
let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage
self.storage()?
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(
storage,
@@ -1307,11 +1322,13 @@ impl Tensor {
layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
op,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
} else {
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage
self.storage()?
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, op, false))
}
@@ -1507,11 +1524,15 @@ impl Tensor {
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage
arg.storage()?
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(from_storage(storage, shape, op, false))
}
fn storage(&self) -> Result<std::cell::Ref<'_, Storage>> {
Ok(self.storage.try_borrow()?)
}
}
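
The `?` on `try_borrow` implies that the crate's `Error` can be built from `std::cell::BorrowError`; that `From` impl presumably lands in one of the other changed files. A minimal sketch of the assumed conversion (variant name hypothetical):

```rust
use std::cell::{BorrowError, RefCell};

#[derive(Debug)]
enum Error {
    // Hypothetical variant; the real enum has many others.
    Borrow(BorrowError),
}

impl From<BorrowError> for Error {
    fn from(e: BorrowError) -> Self {
        Error::Borrow(e)
    }
}

type Result<T> = std::result::Result<T, Error>;

// With the impl in place, `?` converts the borrow failure for us.
fn read(cell: &RefCell<u32>) -> Result<u32> {
    Ok(*cell.try_borrow()?)
}

fn main() {
    let cell = RefCell::new(7);
    assert_eq!(read(&cell).unwrap(), 7);
}
```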
macro_rules! bin_trait {