Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)
Tensor mutability (#154)
* Working towards tensor mutability.

* Use a ref-cell to provide tensor mutability.
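The heart of the change: tensor storage moves from Arc<Storage> to Arc<RefCell<Storage>>, trading compile-time borrow checking for dynamic checks so the data behind a tensor can be mutated in place (e.g. for variables). A minimal sketch of the pattern, using a stand-in Storage type rather than candle's actual definitions:

use std::cell::{Ref, RefCell};
use std::sync::Arc;

// Stand-in for candle's Storage; the real type wraps CPU or CUDA buffers.
pub struct Storage(Vec<f32>);

pub struct Tensor {
    storage: Arc<RefCell<Storage>>,
}

impl Tensor {
    // Mirrors the storage() accessor added in this commit: a failed dynamic
    // borrow surfaces as an error value instead of a panic.
    pub fn storage(&self) -> Result<Ref<'_, Storage>, std::cell::BorrowError> {
        self.storage.try_borrow()
    }
}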
@@ -5,7 +5,7 @@ members = [
     "candle-kernels",
     "candle-hub",
     "candle-nn",
-    "candle-pyo3",
+    # "candle-pyo3",
     "candle-transformers",
 ]
 
@@ -222,7 +222,7 @@ impl Tensor {
                 } else {
                     let mut dims = arg_dims.to_vec();
                     dims[dim] = start_idx;
-                    Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
+                    Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                 };
                 let right_pad = arg_dims[dim] - start_idx - len;
                 let right_pad = if right_pad == 0 {
@@ -230,7 +230,7 @@ impl Tensor {
                 } else {
                     let mut dims = arg_dims.to_vec();
                     dims[dim] = right_pad;
-                    Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
+                    Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                 };
                 let arg_grad = match (left_pad, right_pad) {
                     (None, None) => grad,
@@ -264,7 +264,7 @@ impl Tensor {
                 }
                 Op::ToDevice(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    let arg_grad = grad.to_device(&sum_grad.device())?;
+                    let arg_grad = grad.to_device(sum_grad.device())?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Transpose(arg, dim1, dim2) => {
@@ -155,6 +155,12 @@ pub enum Error {
     #[error(transparent)]
     SafeTensor(#[from] safetensors::SafeTensorError),
 
+    // Maybe we could have a more detailed error here, including the line of the function that
+    // triggered this or some backtrace.
+    /// Borrow error.
+    #[error(transparent)]
+    BorrowError(#[from] std::cell::BorrowError),
+
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
 
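Because the new variant derives its conversion with #[from], any std::cell::BorrowError produced by try_borrow() is converted automatically by the ? operator. A small illustration of the mechanism with a hypothetical error enum (not candle's full Error type):

use std::cell::RefCell;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum MyError {
    #[error(transparent)]
    BorrowError(#[from] std::cell::BorrowError),
}

fn len_of(cell: &RefCell<Vec<u8>>) -> Result<usize, MyError> {
    // `?` maps std::cell::BorrowError into MyError::BorrowError via #[from].
    Ok(cell.try_borrow()?.len())
}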
@@ -40,7 +40,8 @@ impl st::View for Tensor {
 
     fn data(&self) -> Cow<[u8]> {
         // This copies data from GPU to CPU.
-        convert_back(self).unwrap()
+        // TODO: Avoid the unwrap here.
+        Cow::Owned(convert_back(self).unwrap())
     }
 
     fn data_len(&self) -> usize {
@@ -86,19 +87,18 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     }
 }
 
-fn convert_back_<T: WithDType>(value: Cow<'_, [T]>) -> Cow<'_, [u8]> {
+fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
     let size_in_bytes = T::DTYPE.size_in_bytes();
+    let length = vs.len() * size_in_bytes;
+    let capacity = vs.capacity() * size_in_bytes;
+    let ptr = vs.as_mut_ptr() as *mut u8;
+    // Don't run the destructor for Vec<T>.
+    std::mem::forget(vs);
     // SAFETY:
     //
     // Every T is larger than u8, so there is no issue regarding alignment.
-    // This is safe only because we explicitly take the lifetime from the Cow's lifetime
-    // and consume the original Cow.
-    // This means that borrowed Cow, will keep their lifetime information, preventing
-    // this slice from being accessed after freeing the original memory.
-    let slice = unsafe {
-        std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes)
-    };
-    Cow::Borrowed(slice)
+    // This re-interprets the Vec<T> as a Vec<u8>.
+    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
 }
 
 pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
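The new convert_back_ drops the borrowed-Cow bookkeeping and instead reinterprets the allocation in place: pointer, length, and capacity are rescaled to byte units and the original Vec is forgotten so the buffer is freed exactly once, by the resulting Vec<u8>. The same trick in isolation, specialized to f32 for readability (a sketch mirroring the code above, not an additional candle API):

fn f32s_to_bytes(mut vs: Vec<f32>) -> Vec<u8> {
    let size_in_bytes = std::mem::size_of::<f32>();
    let length = vs.len() * size_in_bytes;
    let capacity = vs.capacity() * size_in_bytes;
    let ptr = vs.as_mut_ptr() as *mut u8;
    // Don't run the destructor for the original Vec: the Vec<u8> built below
    // takes over ownership of the allocation.
    std::mem::forget(vs);
    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
}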
@@ -113,14 +113,16 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     }
 }
 
-pub fn convert_back(tensor: &Tensor) -> Result<Cow<[u8]>> {
+pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
+    // TODO: This makes an unnecessary copy when the tensor is on the cpu.
+    let tensor = tensor.flatten_all()?;
     match tensor.dtype() {
-        DType::U8 => Ok(convert_back_::<u8>(tensor.storage_data()?)),
-        DType::U32 => Ok(convert_back_::<u32>(tensor.storage_data()?)),
-        DType::F16 => Ok(convert_back_::<half::f16>(tensor.storage_data()?)),
-        DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.storage_data()?)),
-        DType::F32 => Ok(convert_back_::<f32>(tensor.storage_data()?)),
-        DType::F64 => Ok(convert_back_::<f64>(tensor.storage_data()?)),
+        DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
+        DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
+        DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
+        DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
+        DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
+        DType::F64 => Ok(convert_back_::<f64>(tensor.to_vec1()?)),
     }
 }
 
@@ -183,7 +185,7 @@ mod tests {
         let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
         t.save_safetensors("t", "t.safetensors").unwrap();
         let bytes = std::fs::read("t.safetensors").unwrap();
-        assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0");
+        assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
         std::fs::remove_file("t.safetensors").unwrap();
     }
 
@@ -194,7 +196,7 @@ mod tests {
         let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
         st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
         let bytes = std::fs::read("multi.safetensors").unwrap();
-        assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0");
+        assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
         std::fs::remove_file("multi.safetensors").unwrap();
     }
 }
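The updated byte expectations follow from the safetensors layout: an 8-byte little-endian header length (b"@\0\0\0\0\0\0\0" is 0x40 = 64), the space-padded JSON header, then the raw tensor bytes, which the old assertions were missing. For the 2x2 f32 zeros tensor the data section is:

// data_offsets [0, 16] in the header: 2 * 2 elements * 4 bytes each,
// i.e. the sixteen trailing \0 bytes added to the expected output.
let n_data_bytes = 2 * 2 * std::mem::size_of::<f32>();
assert_eq!(n_data_bytes, 16);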
@@ -216,6 +216,7 @@ impl Dim for usize {
     }
 }
 
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
 pub enum D {
     Minus1,
     Minus2,
@@ -1,6 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
+use std::cell::RefCell;
 use std::sync::Arc;
 
 /// Unique identifier for tensors.
@@ -18,10 +19,23 @@ impl TensorId {
 
 pub struct Tensor_ {
     id: TensorId,
-    storage: Arc<Storage>,
+    // Storage uses a ref-cell here so inner mutability is available and borrow rules are checked
+    // dynamically. The alternatives would be:
+    // - Using a mutex, which would have the highest cost when retrieving the storage but would
+    //   prevent errors when concurrent accesses take place. A mutex would also be subject to
+    //   deadlocks, e.g. with the current code if the same tensor were used twice by a single
+    //   binary op.
+    // - Using an unsafe cell, which would have the lowest cost but undefined behavior on
+    //   concurrent accesses.
+    // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the
+    // data, and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables,
+    // but that's tricky to encode in the current setup.
+    storage: Arc<RefCell<Storage>>,
     layout: Layout,
     op: Option<Op>,
     is_variable: bool,
+    dtype: DType,
+    device: Device,
 }
 
 impl AsRef<Tensor> for Tensor {
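The deadlock the comment warns about is easy to reproduce with a plain Mutex: a binary op locks both operands, and when both are the same tensor the second lock blocks (or panics) since std's Mutex is not reentrant. RefCell sidesteps this because two shared borrows may coexist. A standalone sketch, not candle code:

use std::sync::{Arc, Mutex};

fn add(lhs: &Arc<Mutex<Vec<f32>>>, rhs: &Arc<Mutex<Vec<f32>>>) -> Vec<f32> {
    let a = lhs.lock().unwrap();
    // Deadlocks (or panics) when lhs and rhs are the same Arc, as in t.add(&t).
    let b = rhs.lock().unwrap();
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
// With RefCell, two simultaneous try_borrow() calls both succeed, so
// self-application keeps working; only a read/write overlap is rejected.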
@@ -62,7 +76,7 @@ macro_rules! unary_op {
         pub fn $fn_name(&self) -> Result<Self> {
             let shape = self.shape();
             let storage = self
-                .storage
+                .storage()?
                 .unary_impl::<crate::op::$op_name>(self.layout())?;
             let op = if self.track_op() {
                 Some(Op::$op_name(self.clone()))
@@ -78,8 +92,8 @@ macro_rules! binary_op {
     ($fn_name:ident, $op_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
             let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage = self.storage.binary_impl::<crate::op::$op_name>(
-                &rhs.storage,
+            let storage = self.storage()?.binary_impl::<crate::op::$op_name>(
+                &*rhs.storage()?,
                 self.layout(),
                 rhs.layout(),
             )?;
@@ -119,12 +133,16 @@ fn from_storage<S: Into<Shape>>(
     op: Option<Op>,
     is_variable: bool,
 ) -> Tensor {
+    let dtype = storage.dtype();
+    let device = storage.device();
     let tensor_ = Tensor_ {
         id: TensorId::new(),
-        storage: Arc::new(storage),
+        storage: Arc::new(RefCell::new(storage)),
         layout: Layout::contiguous(shape),
         op,
         is_variable,
+        dtype,
+        device,
     };
     Tensor(Arc::new(tensor_))
 }
@@ -169,7 +187,7 @@ impl Tensor {
     /// # Ok::<(), candle::Error>(())
     /// ```
     pub fn ones_like(&self) -> Result<Self> {
-        Tensor::ones(self.shape(), self.dtype(), &self.device())
+        Tensor::ones(self.shape(), self.dtype(), self.device())
     }
 
     /// Creates a new tensor filled with zeros.
@@ -219,7 +237,7 @@ impl Tensor {
     /// # Ok::<(), candle::Error>(())
     /// ```
     pub fn zeros_like(&self) -> Result<Self> {
-        Tensor::zeros(self.shape(), self.dtype(), &self.device())
+        Tensor::zeros(self.shape(), self.dtype(), self.device())
     }
 
     fn rand_impl<S: Into<Shape>>(
@@ -502,7 +520,7 @@ impl Tensor {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
            Ok::<_, Error>(data[self.layout().start_offset()])
         };
-        match self.storage.as_ref() {
+        match &*self.storage()? {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -520,7 +538,7 @@ impl Tensor {
     /// # Ok::<(), candle::Error>(())
     /// ```
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
-        let storage = self.storage.affine(self.layout(), mul, add)?;
+        let storage = self.storage()?.affine(self.layout(), mul, add)?;
         let op = if self.track_op() {
             Some(Op::Affine {
                 arg: self.clone(),
@@ -535,7 +553,7 @@ impl Tensor {
 
     /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
     pub fn elu(&self, alpha: f64) -> Result<Self> {
-        let storage = self.storage.elu(self.layout(), alpha)?;
+        let storage = self.storage()?.elu(self.layout(), alpha)?;
         let op = if self.track_op() {
             Some(Op::Elu(self.clone(), alpha))
         } else {
@@ -585,6 +603,8 @@ impl Tensor {
             layout,
             op,
             is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -616,7 +636,9 @@ impl Tensor {
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
-            let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?;
+            let mut storage = self
+                .storage()?
+                .unary_impl::<crate::op::Exp>(self.layout())?;
             // The resulting storage is contiguous.
             storage.divide_by_sum_over_dim(shape, dim)?;
             let op = if self.track_op() {
@@ -649,7 +671,7 @@ impl Tensor {
         for &dim in sum_dims {
             self.check_dim(dim, "sum")?;
         }
-        let storage = self.storage.sum(self.layout(), sum_dims)?;
+        let storage = self.storage()?.sum(self.layout(), sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -695,8 +717,8 @@ impl Tensor {
             stride,
         };
         let storage =
-            self.storage
-                .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
+            self.storage()?
+                .conv1d(self.layout(), &*kernel.storage()?, kernel.layout(), &params)?;
         let op = if self.track_op() || kernel.track_op() {
             Some(Op::Conv1D {
                 arg: self.clone(),
@@ -749,8 +771,8 @@ impl Tensor {
             })?
         }
 
-        let storage = self.storage.matmul(
-            &rhs.storage,
+        let storage = self.storage()?.matmul(
+            &*rhs.storage()?,
             (batching, m, n, k),
             self.layout(),
             rhs.layout(),
@@ -769,11 +791,11 @@ impl Tensor {
     pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
         let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
         let shape = self.same_shape_binary_op(on_false, "where_cond")?;
-        let storage = self.storage.where_cond(
+        let storage = self.storage()?.where_cond(
             self.layout(),
-            &on_true.storage,
+            &*on_true.storage()?,
             on_true.layout(),
-            &on_false.storage,
+            &*on_false.storage()?,
             on_false.layout(),
         )?;
         let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
@@ -821,8 +843,8 @@ impl Tensor {
         let seq_len = ids_shape.r1()?;
         let (_, hidden_size) = rhs.shape().r2()?;
         let storage = ids
-            .storage
-            .embedding(ids.layout(), &rhs.storage, rhs.layout())?;
+            .storage()?
+            .embedding(ids.layout(), &*rhs.storage()?, rhs.layout())?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
@@ -836,23 +858,6 @@ impl Tensor {
         self.layout.strided_index()
     }
 
-    /// Returns data from the underlying storage, this does not take the strides
-    /// into account so the size of the resulting buffer might be larger than the
-    /// tensor number of elements.
-    pub fn storage_data<S: crate::WithDType>(&self) -> Result<std::borrow::Cow<[S]>> {
-        match self.storage.as_ref() {
-            Storage::Cpu(cpu_storage) => {
-                let slice = S::cpu_storage_as_slice(cpu_storage)?;
-                Ok(std::borrow::Cow::Borrowed(slice))
-            }
-            Storage::Cuda(slice) => {
-                let cpu_storage = slice.to_cpu_storage()?;
-                let storage_data = S::cpu_storage_data(cpu_storage)?;
-                Ok(std::borrow::Cow::Owned(storage_data))
-            }
-        }
-    }
-
     /// Returns the data contained in a 1D tensor as a vector of scalar values.
     pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
         if self.rank() != 1 {
@@ -862,7 +867,7 @@ impl Tensor {
                 shape: self.shape().clone(),
             });
         }
-        match self.storage.as_ref() {
+        match &*self.storage()? {
             Storage::Cpu(cpu_storage) => {
                 let data = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(self.strided_index().map(|i| data[i]).collect())
@@ -890,7 +895,7 @@ impl Tensor {
             assert!(src_index.next().is_none());
             Ok(rows)
         };
-        match self.storage.as_ref() {
+        match &*self.storage()? {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -914,7 +919,7 @@ impl Tensor {
             assert!(src_index.next().is_none());
             Ok(top_rows)
         };
-        match self.storage.as_ref() {
+        match &*self.storage()? {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -922,12 +927,12 @@ impl Tensor {
 
     /// The dtype for the elements stored in the input tensor.
     pub fn dtype(&self) -> DType {
-        self.storage.dtype()
+        self.dtype
     }
 
     /// The device on which the input tensor is located.
-    pub fn device(&self) -> Device {
-        self.storage.device()
+    pub fn device(&self) -> &Device {
+        &self.device
     }
 
     /// The tensor shape, i.e. dimension sizes on each axis.
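Caching dtype and device on Tensor_ means metadata queries never go through the RefCell, so they cannot fail and cannot collide with an in-flight mutable borrow of the storage. A sketch of why that matters, with stand-in types rather than candle's:

use std::cell::RefCell;

struct T {
    data: RefCell<Vec<f32>>,
    len: usize, // cached metadata, analogous to dtype/device above
}

fn demo(t: &T) {
    let mut guard = t.data.borrow_mut();
    guard.push(1.0);
    // t.len stays freely readable here; t.data.try_borrow() would fail.
    let _ = t.len;
}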
@@ -1114,6 +1119,8 @@ impl Tensor {
             layout: self.layout.transpose(dim1, dim2)?,
             op,
             is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -1133,10 +1140,12 @@ impl Tensor {
     pub fn copy(&self) -> Result<Tensor> {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: Arc::new(self.storage.try_clone(self.layout())?),
+            storage: Arc::new(RefCell::new(self.storage()?.try_clone(self.layout())?)),
             layout: self.layout.clone(),
             op: None, // TODO
             is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -1150,6 +1159,8 @@ impl Tensor {
             layout: self.layout.clone(),
             op: None,
             is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -1159,7 +1170,7 @@ impl Tensor {
         if self.device().same_device(device) {
             Ok(self.clone())
         } else {
-            let storage = match (self.storage.as_ref(), device) {
+            let storage = match (&*self.storage()?, device) {
                 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                     Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
                 }
@@ -1179,10 +1190,12 @@ impl Tensor {
             };
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
-                storage: Arc::new(storage),
+                storage: Arc::new(RefCell::new(storage)),
                 layout: self.layout.clone(),
                 op,
                 is_variable: false,
+                dtype: self.dtype,
+                device: device.clone(),
             };
             Ok(Tensor(Arc::new(tensor_)))
         }
@@ -1216,6 +1229,8 @@ impl Tensor {
             layout: self.layout.broadcast_as(shape)?,
             op,
             is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -1240,7 +1255,7 @@ impl Tensor {
             Ok(self.clone())
         } else {
             let shape = self.shape();
-            let storage = self.storage.to_dtype(self.layout(), dtype)?;
+            let storage = self.storage()?.to_dtype(self.layout(), dtype)?;
             let op = if self.track_op() {
                 Some(Op::ToDType(self.clone()))
             } else {
@@ -1258,7 +1273,7 @@ impl Tensor {
         } else {
             let shape = self.shape();
             let mut storage = self.device().zeros(shape, self.dtype())?;
-            self.storage
+            self.storage()?
                 .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(
                 storage,
@@ -1307,11 +1322,13 @@ impl Tensor {
                 layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
                 op,
                 is_variable: false,
+                dtype: self.dtype,
+                device: self.device.clone(),
             };
             Ok(Tensor(Arc::new(tensor_)))
         } else {
             let mut storage = self.device().zeros(&shape, self.dtype())?;
-            self.storage
+            self.storage()?
                 .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(storage, shape, op, false))
         }
@@ -1507,11 +1524,15 @@ impl Tensor {
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
-            arg.storage
+            arg.storage()?
                 .copy_strided_src(&mut storage, offset, arg.layout())?;
         }
         Ok(from_storage(storage, shape, op, false))
     }
+
+    fn storage(&self) -> Result<std::cell::Ref<'_, Storage>> {
+        Ok(self.storage.try_borrow()?)
+    }
 }
 
 macro_rules! bin_trait {
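With device() now returning &Device, the call-site edits below (bert, falcon, llama, encodec, musicgen, whisper, pyo3) are purely mechanical: drop the extra &. Mirroring ones_like/zeros_like above, assuming candle's public API at this commit:

use candle::{Result, Tensor};

// device() hands out &Device directly, so helpers forward it as-is.
fn zeros_like(t: &Tensor) -> Result<Tensor> {
    Tensor::zeros(t.shape(), t.dtype(), t.device())
}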
@@ -27,12 +27,18 @@ fn matmul_grad(device: &Device) -> Result<()> {
     assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3)));
     assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2)));
     assert_eq!(
-        &*grad_x.storage_data::<f32>()?,
-        &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.]
+        &*grad_x.to_vec3::<f32>()?,
+        &[
+            [[1., 5., 9.], [1., 5., 9.]],
+            [[13., 17., 21.], [13., 17., 21.]]
+        ]
     );
     assert_eq!(
-        &*grad_y.storage_data::<f32>()?,
-        &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.]
+        &*grad_y.to_vec3::<f32>()?,
+        &[
+            [[3., 3.], [5., 5.], [7., 7.]],
+            [[15., 15.], [17., 17.], [19., 19.]]
+        ]
     );
     Ok(())
 }
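to_vec3 walks the data through the strided index, so the assertions compare logically ordered nested rows instead of a flat storage dump (the removed storage_data could even over-report elements for non-contiguous tensors). Continuing the test above as a sketch:

let grad_x_rows = grad_x.to_vec3::<f32>()?; // Vec<Vec<Vec<f32>>>
assert_eq!(grad_x_rows[0][1], vec![1., 5., 9.]);
assert_eq!(grad_x_rows[1][0], vec![13., 17., 21.]);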
@@ -196,7 +196,7 @@ impl BertEmbeddings {
         if let Some(position_embeddings) = &self.position_embeddings {
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
-            let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
+            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
             embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
         }
         let embeddings = self.layer_norm.forward(&embeddings)?;
@@ -183,7 +183,7 @@ impl FalconRotaryEmbedding {
         past_kv_len: usize,
     ) -> Result<(Tensor, Tensor)> {
         let (_batch, seq_len, _head_dim) = query.shape().r3()?;
-        let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, &query.device(), query.dtype())?;
+        let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
         let cos = cos.narrow(0, past_kv_len, seq_len)?;
         let sin = sin.narrow(0, past_kv_len, seq_len)?;
         let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
@@ -194,7 +194,7 @@ impl FalconRotaryEmbedding {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
@@ -471,7 +471,7 @@ impl Falcon {
             Some((k, _)) => k.dim(1)?,
             None => 0,
         };
-        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
+        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
         for block in self.blocks.iter_mut() {
             hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
         }
@@ -227,7 +227,7 @@ impl CausalSelfAttention {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
@@ -180,7 +180,7 @@ impl EncodecResidualVectorQuantizer {
     }
 
     fn decode(&self, codes: &Tensor) -> Result<Tensor> {
-        let mut quantized_out = Tensor::zeros((), DType::F32, &codes.device())?;
+        let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
         if codes.dim(0)? != self.layers.len() {
             anyhow::bail!(
                 "codes shape {:?} does not match the number of quantization layers {}",
@@ -311,13 +311,13 @@ impl MusicgenDecoder {
         let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
         let b_sz = b_sz_times_codebooks / self.num_codebooks;
         let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
-        let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, &dev)?;
+        let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
         for (idx, codebook) in self.embed_tokens.iter().enumerate() {
             let inp = input.narrow(1, idx, 1)?.squeeze(1)?;
             inputs_embeds = (inputs_embeds + codebook.forward(&inp)?)?
         }
         let inputs_embeds = inputs_embeds;
-        let positions = self.embed_positions.forward(&input)?.to_device(&dev)?;
+        let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
         let mut xs = inputs_embeds.broadcast_add(&positions)?;
         let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
         for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
@@ -109,7 +109,7 @@ impl Decoder {
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![SOT_TOKEN];
         for i in 0..sample_len {
-            let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
+            let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
 
             // The model expects a batch dim but this inference loop does not handle
             // it so we add it at this point.
@@ -50,7 +50,7 @@ enum PyDevice {
 }
 
 impl PyDevice {
-    fn from_device(device: Device) -> Self {
+    fn from_device(device: &Device) -> Self {
         match device {
             Device::Cpu => Self::Cpu,
             Device::Cuda(_) => Self::Cuda,