mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the casting operation.
This commit is contained in:
@ -13,12 +13,12 @@ pub enum CpuStorage {
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
|
||||
fn unary_map<T: Copy, F: FnMut(T) -> T>(
|
||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
vs: &[T],
|
||||
mut f: F,
|
||||
) -> Vec<T> {
|
||||
) -> Vec<U> {
|
||||
if shape.is_contiguous(stride) {
|
||||
vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
|
||||
} else {
|
||||
@ -105,6 +105,48 @@ impl CpuStorage {
|
||||
D::cpu_storage_as_mut_slice(self)
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
// TODO: find a way around the quadratic number of cases below.
|
||||
match (self, dtype) {
|
||||
(Self::U32(storage), DType::F32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F32(storage), DType::U32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F64(storage), DType::U32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::F64) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F64) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F64) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
|
@ -242,6 +242,10 @@ impl CudaStorage {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement embedding"))
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
@ -400,7 +404,7 @@ impl CudaStorage {
|
||||
_hidden_size: usize,
|
||||
_vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
todo!("Implement embedding for gpu");
|
||||
Err(CudaError::InternalError("TODO: implement embedding"))
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
|
@ -62,6 +62,10 @@ impl CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ pub(crate) enum Op {
|
||||
mul: f64,
|
||||
add: f64,
|
||||
},
|
||||
ToDType(Tensor),
|
||||
Exp(Tensor),
|
||||
Log(Tensor),
|
||||
Sin(Tensor),
|
||||
|
@ -60,7 +60,6 @@ impl Storage {
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
// TODO: Different code path for the contiguous case?
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
||||
@ -73,6 +72,19 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.to_dtype(shape, stride, dtype)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.to_dtype(shape, stride, dtype)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: op::UnaryOp>(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
|
@ -602,6 +602,17 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::ToDType(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
|
||||
pub fn contiguous(&self) -> Result<Tensor> {
|
||||
if self.is_contiguous() {
|
||||
Ok(self.clone())
|
||||
@ -773,6 +784,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
Op::Reshape(node)
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Softmax(node, _)
|
||||
@ -892,6 +904,10 @@ impl Tensor {
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
|
||||
}
|
||||
Op::Affine { arg, mul, .. } => {
|
||||
let arg_grad = grad.affine(*mul, 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user