mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Tensor mutability (#154)
* Working towards tensor mutability. * Use a ref-cell to provide tensor mutability.
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::shape::Dim;
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::cell::RefCell;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Unique identifier for tensors.
|
||||
@ -18,10 +19,23 @@ impl TensorId {
|
||||
|
||||
pub struct Tensor_ {
|
||||
id: TensorId,
|
||||
storage: Arc<Storage>,
|
||||
// Storage uses a mutex here so inner mutability is available and borrow rules are checked
|
||||
// dynamically. The alternatives would be:
|
||||
// - Using a mutex, this would have the highest cost when retrieving the storage but would
|
||||
// prevent errors when concurrent access takes place. Mutex would also be subject to
|
||||
// deadlocks for example using the current code if the same tensor is used twice by a single
|
||||
// binary op.
|
||||
// - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
|
||||
// accesses.
|
||||
// Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
|
||||
// and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
|
||||
// that's tricky to encode in the current setup.
|
||||
storage: Arc<RefCell<Storage>>,
|
||||
layout: Layout,
|
||||
op: Option<Op>,
|
||||
is_variable: bool,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl AsRef<Tensor> for Tensor {
|
||||
@ -62,7 +76,7 @@ macro_rules! unary_op {
|
||||
pub fn $fn_name(&self) -> Result<Self> {
|
||||
let shape = self.shape();
|
||||
let storage = self
|
||||
.storage
|
||||
.storage()?
|
||||
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::$op_name(self.clone()))
|
||||
@ -78,8 +92,8 @@ macro_rules! binary_op {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||
let storage = self.storage.binary_impl::<crate::op::$op_name>(
|
||||
&rhs.storage,
|
||||
let storage = self.storage()?.binary_impl::<crate::op::$op_name>(
|
||||
&*rhs.storage()?,
|
||||
self.layout(),
|
||||
rhs.layout(),
|
||||
)?;
|
||||
@ -119,12 +133,16 @@ fn from_storage<S: Into<Shape>>(
|
||||
op: Option<Op>,
|
||||
is_variable: bool,
|
||||
) -> Tensor {
|
||||
let dtype = storage.dtype();
|
||||
let device = storage.device();
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(storage),
|
||||
storage: Arc::new(RefCell::new(storage)),
|
||||
layout: Layout::contiguous(shape),
|
||||
op,
|
||||
is_variable,
|
||||
dtype,
|
||||
device,
|
||||
};
|
||||
Tensor(Arc::new(tensor_))
|
||||
}
|
||||
@ -169,7 +187,7 @@ impl Tensor {
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn ones_like(&self) -> Result<Self> {
|
||||
Tensor::ones(self.shape(), self.dtype(), &self.device())
|
||||
Tensor::ones(self.shape(), self.dtype(), self.device())
|
||||
}
|
||||
|
||||
/// Creates a new tensor filled with zeros.
|
||||
@ -219,7 +237,7 @@ impl Tensor {
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn zeros_like(&self) -> Result<Self> {
|
||||
Tensor::zeros(self.shape(), self.dtype(), &self.device())
|
||||
Tensor::zeros(self.shape(), self.dtype(), self.device())
|
||||
}
|
||||
|
||||
fn rand_impl<S: Into<Shape>>(
|
||||
@ -502,7 +520,7 @@ impl Tensor {
|
||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||
Ok::<_, Error>(data[self.layout().start_offset()])
|
||||
};
|
||||
match self.storage.as_ref() {
|
||||
match &*self.storage()? {
|
||||
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
|
||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
}
|
||||
@ -520,7 +538,7 @@ impl Tensor {
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||
let storage = self.storage.affine(self.layout(), mul, add)?;
|
||||
let storage = self.storage()?.affine(self.layout(), mul, add)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Affine {
|
||||
arg: self.clone(),
|
||||
@ -535,7 +553,7 @@ impl Tensor {
|
||||
|
||||
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
||||
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
||||
let storage = self.storage.elu(self.layout(), alpha)?;
|
||||
let storage = self.storage()?.elu(self.layout(), alpha)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Elu(self.clone(), alpha))
|
||||
} else {
|
||||
@ -585,6 +603,8 @@ impl Tensor {
|
||||
layout,
|
||||
op,
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
@ -616,7 +636,9 @@ impl Tensor {
|
||||
exp.broadcast_div(&sum_exp)
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?;
|
||||
let mut storage = self
|
||||
.storage()?
|
||||
.unary_impl::<crate::op::Exp>(self.layout())?;
|
||||
// The resulting storage is contiguous.
|
||||
storage.divide_by_sum_over_dim(shape, dim)?;
|
||||
let op = if self.track_op() {
|
||||
@ -649,7 +671,7 @@ impl Tensor {
|
||||
for &dim in sum_dims {
|
||||
self.check_dim(dim, "sum")?;
|
||||
}
|
||||
let storage = self.storage.sum(self.layout(), sum_dims)?;
|
||||
let storage = self.storage()?.sum(self.layout(), sum_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
|
||||
} else {
|
||||
@ -695,8 +717,8 @@ impl Tensor {
|
||||
stride,
|
||||
};
|
||||
let storage =
|
||||
self.storage
|
||||
.conv1d(self.layout(), &kernel.storage, kernel.layout(), ¶ms)?;
|
||||
self.storage()?
|
||||
.conv1d(self.layout(), &*kernel.storage()?, kernel.layout(), ¶ms)?;
|
||||
let op = if self.track_op() || kernel.track_op() {
|
||||
Some(Op::Conv1D {
|
||||
arg: self.clone(),
|
||||
@ -749,8 +771,8 @@ impl Tensor {
|
||||
})?
|
||||
}
|
||||
|
||||
let storage = self.storage.matmul(
|
||||
&rhs.storage,
|
||||
let storage = self.storage()?.matmul(
|
||||
&*rhs.storage()?,
|
||||
(batching, m, n, k),
|
||||
self.layout(),
|
||||
rhs.layout(),
|
||||
@ -769,11 +791,11 @@ impl Tensor {
|
||||
pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
|
||||
let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
|
||||
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
|
||||
let storage = self.storage.where_cond(
|
||||
let storage = self.storage()?.where_cond(
|
||||
self.layout(),
|
||||
&on_true.storage,
|
||||
&*on_true.storage()?,
|
||||
on_true.layout(),
|
||||
&on_false.storage,
|
||||
&*on_false.storage()?,
|
||||
on_false.layout(),
|
||||
)?;
|
||||
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
|
||||
@ -821,8 +843,8 @@ impl Tensor {
|
||||
let seq_len = ids_shape.r1()?;
|
||||
let (_, hidden_size) = rhs.shape().r2()?;
|
||||
let storage = ids
|
||||
.storage
|
||||
.embedding(ids.layout(), &rhs.storage, rhs.layout())?;
|
||||
.storage()?
|
||||
.embedding(ids.layout(), &*rhs.storage()?, rhs.layout())?;
|
||||
let shape: Shape = (seq_len, hidden_size).into();
|
||||
let op = if ids.track_op() || rhs.track_op() {
|
||||
Some(Op::Embedding(ids.clone(), rhs.clone()))
|
||||
@ -836,23 +858,6 @@ impl Tensor {
|
||||
self.layout.strided_index()
|
||||
}
|
||||
|
||||
/// Returns data from the underlying storage, this does not take the strides
|
||||
/// into account so the size of the resulting buffer might be larger than the
|
||||
/// tensor number of elements.
|
||||
pub fn storage_data<S: crate::WithDType>(&self) -> Result<std::borrow::Cow<[S]>> {
|
||||
match self.storage.as_ref() {
|
||||
Storage::Cpu(cpu_storage) => {
|
||||
let slice = S::cpu_storage_as_slice(cpu_storage)?;
|
||||
Ok(std::borrow::Cow::Borrowed(slice))
|
||||
}
|
||||
Storage::Cuda(slice) => {
|
||||
let cpu_storage = slice.to_cpu_storage()?;
|
||||
let storage_data = S::cpu_storage_data(cpu_storage)?;
|
||||
Ok(std::borrow::Cow::Owned(storage_data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the data contained in a 1D tensor as a vector of scalar values.
|
||||
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
||||
if self.rank() != 1 {
|
||||
@ -862,7 +867,7 @@ impl Tensor {
|
||||
shape: self.shape().clone(),
|
||||
});
|
||||
}
|
||||
match self.storage.as_ref() {
|
||||
match &*self.storage()? {
|
||||
Storage::Cpu(cpu_storage) => {
|
||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||
Ok(self.strided_index().map(|i| data[i]).collect())
|
||||
@ -890,7 +895,7 @@ impl Tensor {
|
||||
assert!(src_index.next().is_none());
|
||||
Ok(rows)
|
||||
};
|
||||
match self.storage.as_ref() {
|
||||
match &*self.storage()? {
|
||||
Storage::Cpu(storage) => from_cpu_storage(storage),
|
||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
}
|
||||
@ -914,7 +919,7 @@ impl Tensor {
|
||||
assert!(src_index.next().is_none());
|
||||
Ok(top_rows)
|
||||
};
|
||||
match self.storage.as_ref() {
|
||||
match &*self.storage()? {
|
||||
Storage::Cpu(storage) => from_cpu_storage(storage),
|
||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
}
|
||||
@ -922,12 +927,12 @@ impl Tensor {
|
||||
|
||||
/// The dtype for the elements stored in the input tensor.
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.storage.dtype()
|
||||
self.dtype
|
||||
}
|
||||
|
||||
/// The device on which the input tensor is located.
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// The tensor shape, i.e. dimension sizes on each axis.
|
||||
@ -1114,6 +1119,8 @@ impl Tensor {
|
||||
layout: self.layout.transpose(dim1, dim2)?,
|
||||
op,
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
@ -1133,10 +1140,12 @@ impl Tensor {
|
||||
pub fn copy(&self) -> Result<Tensor> {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(self.storage.try_clone(self.layout())?),
|
||||
storage: Arc::new(RefCell::new(self.storage()?.try_clone(self.layout())?)),
|
||||
layout: self.layout.clone(),
|
||||
op: None, // TODO
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
@ -1150,6 +1159,8 @@ impl Tensor {
|
||||
layout: self.layout.clone(),
|
||||
op: None,
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
@ -1159,7 +1170,7 @@ impl Tensor {
|
||||
if self.device().same_device(device) {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let storage = match (self.storage.as_ref(), device) {
|
||||
let storage = match (&*self.storage()?, device) {
|
||||
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
|
||||
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
@ -1179,10 +1190,12 @@ impl Tensor {
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(storage),
|
||||
storage: Arc::new(RefCell::new(storage)),
|
||||
layout: self.layout.clone(),
|
||||
op,
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
@ -1216,6 +1229,8 @@ impl Tensor {
|
||||
layout: self.layout.broadcast_as(shape)?,
|
||||
op,
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
@ -1240,7 +1255,7 @@ impl Tensor {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage.to_dtype(self.layout(), dtype)?;
|
||||
let storage = self.storage()?.to_dtype(self.layout(), dtype)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::ToDType(self.clone()))
|
||||
} else {
|
||||
@ -1258,7 +1273,7 @@ impl Tensor {
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||
self.storage
|
||||
self.storage()?
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
Ok(from_storage(
|
||||
storage,
|
||||
@ -1307,11 +1322,13 @@ impl Tensor {
|
||||
layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
|
||||
op,
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
} else {
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
self.storage
|
||||
self.storage()?
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
@ -1507,11 +1524,15 @@ impl Tensor {
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
let arg = arg.as_ref();
|
||||
arg.storage
|
||||
arg.storage()?
|
||||
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
||||
}
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
fn storage(&self) -> Result<std::cell::Ref<'_, Storage>> {
|
||||
Ok(self.storage.try_borrow()?)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
Reference in New Issue
Block a user