Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)
Use a rwlock for inner mutability. (#156)
* Use a rw-lock.
* Make clippy happier.
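For context: `RefCell` only provides single-threaded interior mutability (borrows are checked at runtime and the type is not `Sync`), while `RwLock` allows any number of concurrent readers or one exclusive writer and lets the wrapped storage be shared across threads. Below is a minimal sketch of the pattern this commit moves to, using a placeholder `Storage` type rather than the crate's real definitions:

use std::sync::{Arc, RwLock};
use std::thread;

// Placeholder standing in for the crate's real Storage type.
struct Storage(Vec<f32>);

// Shared, thread-safe handle: clones point at the same buffer.
#[derive(Clone)]
struct TensorLike {
    storage: Arc<RwLock<Storage>>,
}

impl TensorLike {
    fn new(data: Vec<f32>) -> Self {
        Self {
            storage: Arc::new(RwLock::new(Storage(data))),
        }
    }

    // Read access: any number of threads may hold read guards at once.
    fn sum(&self) -> f32 {
        self.storage.read().unwrap().0.iter().sum()
    }

    // Write access: exclusive, e.g. for updating a variable in place.
    fn fill(&self, value: f32) {
        self.storage.write().unwrap().0.iter_mut().for_each(|v| *v = value);
    }
}

fn main() {
    let t = TensorLike::new(vec![1.0, 2.0, 3.0]);
    let t2 = t.clone(); // shares the same underlying storage
    // Arc<RwLock<_>> is Send + Sync, so a clone can be moved to another thread.
    let handle = thread::spawn(move || t2.sum());
    println!("sum from other thread: {}", handle.join().unwrap()); // 6
    t.fill(0.0); // exclusive write through the shared handle
    println!("sum after fill: {}", t.sum()); // 0
}

With this layout, a cloned tensor handle is just another `Arc` pointing at the same lock, which is the shape the `from_storage` and `copy` changes in the diff below construct.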
@@ -5,7 +5,7 @@ members = [
     "candle-kernels",
     "candle-hub",
     "candle-nn",
-    # "candle-pyo3",
+    "candle-pyo3",
     "candle-transformers",
 ]
@@ -155,12 +155,6 @@ pub enum Error {
     #[error(transparent)]
     SafeTensor(#[from] safetensors::SafeTensorError),
 
-    // Maybe we could have a more detailed error here, including the line of the function that
-    // triggered this or some backtrace.
-    /// Borrow error.
-    #[error(transparent)]
-    BorrowError(#[from] std::cell::BorrowError),
-
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
@@ -1,8 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
-use std::cell::RefCell;
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
 
 /// Unique identifier for tensors.
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -31,7 +30,7 @@ pub struct Tensor_ {
     // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
     // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
     // that's tricky to encode in the current setup.
-    storage: Arc<RefCell<Storage>>,
+    storage: Arc<RwLock<Storage>>,
     layout: Layout,
     op: Option<Op>,
     is_variable: bool,
@@ -77,7 +76,7 @@ macro_rules! unary_op {
         pub fn $fn_name(&self) -> Result<Self> {
             let shape = self.shape();
             let storage = self
-                .storage()?
+                .storage()
                 .unary_impl::<crate::op::$op_name>(self.layout())?;
             let op = if self.track_op() {
                 Some(Op::$op_name(self.clone()))
@@ -93,8 +92,8 @@ macro_rules! binary_op {
     ($fn_name:ident, $op_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
             let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage = self.storage()?.binary_impl::<crate::op::$op_name>(
-                &*rhs.storage()?,
+            let storage = self.storage().binary_impl::<crate::op::$op_name>(
+                &*rhs.storage(),
                 self.layout(),
                 rhs.layout(),
             )?;
@@ -138,7 +137,7 @@ fn from_storage<S: Into<Shape>>(
     let device = storage.device();
     let tensor_ = Tensor_ {
         id: TensorId::new(),
-        storage: Arc::new(RefCell::new(storage)),
+        storage: Arc::new(RwLock::new(storage)),
         layout: Layout::contiguous(shape),
         op,
         is_variable,
@@ -521,7 +520,7 @@ impl Tensor {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             Ok::<_, Error>(data[self.layout().start_offset()])
         };
-        match &*self.storage()? {
+        match &*self.storage() {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -539,7 +538,7 @@ impl Tensor {
     /// # Ok::<(), candle::Error>(())
     /// ```
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
-        let storage = self.storage()?.affine(self.layout(), mul, add)?;
+        let storage = self.storage().affine(self.layout(), mul, add)?;
         let op = if self.track_op() {
             Some(Op::Affine {
                 arg: self.clone(),
@@ -554,7 +553,7 @@ impl Tensor {
 
     /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
     pub fn elu(&self, alpha: f64) -> Result<Self> {
-        let storage = self.storage()?.elu(self.layout(), alpha)?;
+        let storage = self.storage().elu(self.layout(), alpha)?;
         let op = if self.track_op() {
             Some(Op::Elu(self.clone(), alpha))
         } else {
@@ -637,9 +636,7 @@ impl Tensor {
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
-            let mut storage = self
-                .storage()?
-                .unary_impl::<crate::op::Exp>(self.layout())?;
+            let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
             // The resulting storage is contiguous.
             storage.divide_by_sum_over_dim(shape, dim)?;
             let op = if self.track_op() {
@@ -672,7 +669,7 @@ impl Tensor {
         for &dim in sum_dims {
             self.check_dim(dim, "sum")?;
         }
-        let storage = self.storage()?.sum(self.layout(), sum_dims)?;
+        let storage = self.storage().sum(self.layout(), sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -718,8 +715,8 @@ impl Tensor {
             stride,
         };
         let storage =
-            self.storage()?
-                .conv1d(self.layout(), &*kernel.storage()?, kernel.layout(), &params)?;
+            self.storage()
+                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
         let op = if self.track_op() || kernel.track_op() {
             Some(Op::Conv1D {
                 arg: self.clone(),
@@ -772,8 +769,8 @@ impl Tensor {
             })?
         }
 
-        let storage = self.storage()?.matmul(
-            &*rhs.storage()?,
+        let storage = self.storage().matmul(
+            &rhs.storage(),
             (batching, m, n, k),
             self.layout(),
             rhs.layout(),
@@ -792,11 +789,11 @@ impl Tensor {
     pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
         let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
         let shape = self.same_shape_binary_op(on_false, "where_cond")?;
-        let storage = self.storage()?.where_cond(
+        let storage = self.storage().where_cond(
             self.layout(),
-            &*on_true.storage()?,
+            &on_true.storage(),
             on_true.layout(),
-            &*on_false.storage()?,
+            &on_false.storage(),
             on_false.layout(),
         )?;
         let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
@@ -844,8 +841,8 @@ impl Tensor {
         let seq_len = ids_shape.r1()?;
         let (_, hidden_size) = rhs.shape().r2()?;
         let storage = ids
-            .storage()?
-            .embedding(ids.layout(), &*rhs.storage()?, rhs.layout())?;
+            .storage()
+            .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
@@ -868,7 +865,7 @@ impl Tensor {
                 shape: self.shape().clone(),
             });
         }
-        match &*self.storage()? {
+        match &*self.storage() {
             Storage::Cpu(cpu_storage) => {
                 let data = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(self.strided_index().map(|i| data[i]).collect())
@@ -896,7 +893,7 @@ impl Tensor {
             assert!(src_index.next().is_none());
             Ok(rows)
         };
-        match &*self.storage()? {
+        match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -920,7 +917,7 @@ impl Tensor {
             assert!(src_index.next().is_none());
             Ok(top_rows)
         };
-        match &*self.storage()? {
+        match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -1141,7 +1138,7 @@ impl Tensor {
     pub fn copy(&self) -> Result<Tensor> {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: Arc::new(RefCell::new(self.storage()?.try_clone(self.layout())?)),
+            storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
             layout: self.layout.clone(),
             op: None, // TODO
             is_variable: false,
@@ -1171,7 +1168,7 @@ impl Tensor {
         if self.device().same_device(device) {
             Ok(self.clone())
         } else {
-            let storage = match (&*self.storage()?, device) {
+            let storage = match (&*self.storage(), device) {
                 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                     Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
                 }
@@ -1191,7 +1188,7 @@ impl Tensor {
             };
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
-                storage: Arc::new(RefCell::new(storage)),
+                storage: Arc::new(RwLock::new(storage)),
                 layout: self.layout.clone(),
                 op,
                 is_variable: false,
@@ -1256,7 +1253,7 @@ impl Tensor {
             Ok(self.clone())
         } else {
             let shape = self.shape();
-            let storage = self.storage()?.to_dtype(self.layout(), dtype)?;
+            let storage = self.storage().to_dtype(self.layout(), dtype)?;
             let op = if self.track_op() {
                 Some(Op::ToDType(self.clone()))
             } else {
@@ -1274,7 +1271,7 @@ impl Tensor {
         } else {
             let shape = self.shape();
             let mut storage = self.device().zeros(shape, self.dtype())?;
-            self.storage()?
+            self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(
                 storage,
@@ -1329,7 +1326,7 @@ impl Tensor {
             Ok(Tensor(Arc::new(tensor_)))
         } else {
             let mut storage = self.device().zeros(&shape, self.dtype())?;
-            self.storage()?
+            self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(storage, shape, op, false))
         }
@@ -1525,14 +1522,14 @@ impl Tensor {
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
-            arg.storage()?
+            arg.storage()
                 .copy_strided_src(&mut storage, offset, arg.layout())?;
         }
         Ok(from_storage(storage, shape, op, false))
     }
 
-    fn storage(&self) -> Result<std::cell::Ref<'_, Storage>> {
-        Ok(self.storage.try_borrow()?)
+    fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
+        self.storage.read().unwrap()
     }
 }
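The new `storage` accessor returns a `RwLockReadGuard` directly and unwraps a poisoned lock instead of surfacing a borrow error, which is why the call sites above drop the trailing `?` and why the `BorrowError` variant could be removed from the error enum. A rough illustration of how such a guard is consumed by code that expects a plain `&Storage` (again with placeholder types, not the crate's actual API):

use std::sync::{Arc, RwLock, RwLockReadGuard};

// Placeholder standing in for the crate's real Storage type.
struct Storage(Vec<f32>);

impl Storage {
    fn len(&self) -> usize {
        self.0.len()
    }
}

struct TensorLike {
    storage: Arc<RwLock<Storage>>,
}

impl TensorLike {
    // Same shape as the new accessor: no Result, a poisoned lock just panics.
    fn storage(&self) -> RwLockReadGuard<'_, Storage> {
        self.storage.read().unwrap()
    }
}

// Stands in for a backend op that takes a plain &Storage.
fn storage_len(s: &Storage) -> usize {
    s.len()
}

fn main() {
    let t = TensorLike {
        storage: Arc::new(RwLock::new(Storage(vec![0.0; 4]))),
    };
    // `&*guard` re-borrows the guarded value explicitly...
    let n1 = storage_len(&*t.storage());
    // ...while `&guard` relies on deref coercion to &Storage.
    let n2 = storage_len(&t.storage());
    assert_eq!(n1, n2);
    println!("storage length: {n1}");
}

Both forms end up handing the callee a `&Storage`, which matches the mix of `&*x.storage()` and `&x.storage()` seen in the updated call sites.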