Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Add the casting operation.
@@ -13,12 +13,12 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
-fn unary_map<T: Copy, F: FnMut(T) -> T>(
+fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
     shape: &Shape,
     stride: &[usize],
     vs: &[T],
     mut f: F,
-) -> Vec<T> {
+) -> Vec<U> {
     if shape.is_contiguous(stride) {
         vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
     } else {
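The signature change above generalizes the helper from T -> T maps to T -> U maps, which is what makes dtype casting expressible with it. A minimal standalone sketch of the same idea, with the shape/stride handling omitted (illustrative code, not the crate's):

fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], mut f: F) -> Vec<U> {
    // Apply `f` elementwise; the output element type is free to differ from the input's.
    vs.iter().map(|&v| f(v)).collect()
}

fn main() {
    let xs: Vec<f64> = vec![1.5, 2.5, 3.5];
    let ys: Vec<f32> = unary_map(&xs, |v| v as f32); // F64 -> F32 cast
    assert_eq!(ys, vec![1.5f32, 2.5, 3.5]);
}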
@@ -105,6 +105,48 @@ impl CpuStorage {
         D::cpu_storage_as_mut_slice(self)
     }
 
+    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+        // TODO: find a way around the quadratic number of cases below.
+        match (self, dtype) {
+            (Self::U32(storage), DType::F32) => {
+                let data = unary_map(shape, stride, storage, |v| v as f32);
+                Ok(Self::F32(data))
+            }
+            (Self::F32(storage), DType::F32) => {
+                let data = unary_map(shape, stride, storage, |v| v);
+                Ok(Self::F32(data))
+            }
+            (Self::F64(storage), DType::F32) => {
+                let data = unary_map(shape, stride, storage, |v| v as f32);
+                Ok(Self::F32(data))
+            }
+            (Self::U32(storage), DType::U32) => {
+                let data = unary_map(shape, stride, storage, |v| v);
+                Ok(Self::U32(data))
+            }
+            (Self::F32(storage), DType::U32) => {
+                let data = unary_map(shape, stride, storage, |v| v as u32);
+                Ok(Self::U32(data))
+            }
+            (Self::F64(storage), DType::U32) => {
+                let data = unary_map(shape, stride, storage, |v| v as u32);
+                Ok(Self::U32(data))
+            }
+            (Self::U32(storage), DType::F64) => {
+                let data = unary_map(shape, stride, storage, |v| v as f64);
+                Ok(Self::F64(data))
+            }
+            (Self::F32(storage), DType::F64) => {
+                let data = unary_map(shape, stride, storage, |v| v as f64);
+                Ok(Self::F64(data))
+            }
+            (Self::F64(storage), DType::F64) => {
+                let data = unary_map(shape, stride, storage, |v| v);
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
     pub(crate) fn affine_impl(
         &self,
         shape: &Shape,
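The TODO above asks for a way around the quadratic number of (source dtype, target dtype) arms. One possible direction, sketched on a made-up toy enum rather than the crate's types, is to generate the per-source arms with a macro so that the hand-written dispatch only grows linearly with the number of dtypes:

// Toy illustration only; `Storage`, `cast_to!`, and the variants are made up for this sketch.
enum Storage {
    U32(Vec<u32>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}

// One arm per *source* variant; the *target* variant and type are macro parameters.
macro_rules! cast_to {
    ($storage:expr, $variant:ident, $ty:ty) => {
        match $storage {
            Storage::U32(v) => Storage::$variant(v.iter().map(|&x| x as $ty).collect()),
            Storage::F32(v) => Storage::$variant(v.iter().map(|&x| x as $ty).collect()),
            Storage::F64(v) => Storage::$variant(v.iter().map(|&x| x as $ty).collect()),
        }
    };
}

fn main() {
    let s = Storage::U32(vec![1, 2, 3]);
    let _as_f32 = cast_to!(&s, F32, f32);
    let _as_f64 = cast_to!(&s, F64, f64);
}

The outer dispatch on the target DType then becomes a single linear match, one cast_to! invocation per arm.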
@@ -242,6 +242,10 @@ impl CudaStorage {
         &self.device
     }
 
+    pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement embedding"))
+    }
+
     pub(crate) fn affine_impl(
         &self,
         shape: &Shape,
@@ -400,7 +404,7 @@ impl CudaStorage {
         _hidden_size: usize,
         _vocab_size: usize,
     ) -> Result<Self> {
-        todo!("Implement embedding for gpu");
+        Err(CudaError::InternalError("TODO: implement embedding"))
     }
 
     pub(crate) fn matmul_impl(
@@ -62,6 +62,10 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -21,6 +21,7 @@ pub(crate) enum Op {
         mul: f64,
         add: f64,
     },
+    ToDType(Tensor),
     Exp(Tensor),
     Log(Tensor),
     Sin(Tensor),
@@ -60,7 +60,6 @@ impl Storage {
         mul: f64,
         add: f64,
     ) -> Result<Self> {
-        // TODO: Different code path for the contiguous case?
         match self {
             Storage::Cpu(storage) => {
                 let storage = storage.affine_impl(shape, stride, mul, add)?;
@@ -73,6 +72,19 @@ impl Storage {
         }
     }
 
+    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.to_dtype(shape, stride, dtype)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.to_dtype(shape, stride, dtype)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn unary_impl<B: op::UnaryOp>(
         &self,
         shape: &Shape,
@@ -602,6 +602,17 @@ impl Tensor {
         }
     }
 
+    pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
+        let shape = self.shape();
+        let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
+        let op = if self.track_op() {
+            Some(Op::ToDType(self.clone()))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape.clone(), op, false))
+    }
+
     pub fn contiguous(&self) -> Result<Tensor> {
         if self.is_contiguous() {
             Ok(self.clone())
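Putting the layers together, the call path introduced by this commit is Tensor::to_dtype -> Storage::to_dtype -> CpuStorage::to_dtype -> unary_map. A hypothetical usage sketch; only to_dtype and the DType variants come from this diff, while the constructor and device names are assumptions made for illustration:

// Hypothetical: `Tensor::new` and `Device::Cpu` are assumed here, not shown in this diff.
let t = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
let t_f32 = t.to_dtype(DType::F32)?; // U32 -> F32, recorded as Op::ToDType when ops are tracked
let t_f64 = t_f32.to_dtype(DType::F64)?; // F32 -> F64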
@@ -773,6 +784,7 @@ impl Tensor {
                 }
             }
             Op::Reshape(node)
+            | Op::ToDType(node)
             | Op::ToDevice(node)
             | Op::Transpose(node, _, _)
             | Op::Softmax(node, _)
@@ -892,6 +904,10 @@ impl Tensor {
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
                 Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
+                Op::ToDType(arg) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
+                }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
                     let sum_grad = grads.or_insert(arg)?;