mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Adding embedding op (not generic gather, no select).
@@ -8,6 +8,7 @@ use gemm::{gemm, Parallelism};
 // intercept the oom errors to avoid panicking and provide a proper error.
 #[derive(Debug, Clone)]
 pub enum CpuStorage {
+    U32(Vec<u32>),
     F32(Vec<f32>),
     F64(Vec<f64>),
 }
@@ -15,6 +16,7 @@ pub enum CpuStorage {
 impl CpuStorage {
     pub fn dtype(&self) -> DType {
         match self {
+            Self::U32(_) => DType::U32,
             Self::F32(_) => DType::F32,
             Self::F64(_) => DType::F64,
         }
@@ -36,6 +38,13 @@ impl CpuStorage {
         add: f64,
     ) -> Result<Self> {
         match self {
+            Self::U32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let mul = mul as u32;
+                let add = add as u32;
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::U32(data))
+            }
             Self::F32(storage) => {
                 let index = StridedIndex::new(shape.dims(), stride);
                 let mul = mul as f32;
@@ -64,6 +73,9 @@ impl CpuStorage {
                 let data = index.map(|i| B::f64(storage[i])).collect();
                 Ok(Self::F64(data))
             }
+            Self::U32(_storage) => {
+                todo!("No unary for u32 because of neg, sqrt")
+            }
         }
     }
 
@@ -138,6 +150,57 @@ impl CpuStorage {
             }
         }
         Ok(())
+    }
+
+    pub(crate) fn embedding_impl(
+        &self,
+        rhs: &Self,
+        hidden_size: usize,
+        vocab_size: usize,
+    ) -> Result<Self> {
+        match self {
+            CpuStorage::U32(lhs) => match rhs {
+                CpuStorage::F32(rhs) => {
+                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
+                    for &index in lhs {
+                        let index: usize = index.try_into()?;
+                        if index >= vocab_size {
+                            return Err(Error::InvalidIndex {
+                                index,
+                                vocab_size,
+                                op: "embedding",
+                            });
+                        } else {
+                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
+                        }
+                    }
+                    Ok(CpuStorage::F32(weights))
+                }
+                CpuStorage::F64(rhs) => {
+                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
+                    for &index in lhs {
+                        let index: usize = index.try_into()?;
+                        if index >= vocab_size {
+                            return Err(Error::InvalidIndex {
+                                index,
+                                vocab_size,
+                                op: "embedding",
+                            });
+                        } else {
+                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
+                        }
+                    }
+                    Ok(CpuStorage::F64(weights))
+                }
+                rhs => Err(Error::UnexpectedDType {
+                    expected: DType::F32,
+                    got: rhs.dtype(),
+                }),
+            },
+            lhs => Err(Error::UnexpectedDType {
+                expected: DType::U32,
+                got: lhs.dtype(),
+            }),
+        }
     }
 
     pub(crate) fn matmul_impl(
@@ -230,6 +293,10 @@ impl CpuStorage {
     pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
         let elem_count = shape.elem_count();
         match dtype {
+            DType::U32 => {
+                let data = vec![1u32; elem_count];
+                Self::U32(data)
+            }
             DType::F32 => {
                 let data = vec![1f32; elem_count];
                 Self::F32(data)
@@ -244,6 +311,10 @@ impl CpuStorage {
     pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
         let elem_count = shape.elem_count();
        match dtype {
+            DType::U32 => {
+                let data = vec![0u32; elem_count];
+                Self::U32(data)
+            }
             DType::F32 => {
                 let data = vec![0f32; elem_count];
                 Self::F32(data)
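The CPU path above is a plain row gather: each u32 id selects one hidden_size-wide row from a flat, row-major (vocab_size, hidden_size) table, with an InvalidIndex error for out-of-range ids. A standalone sketch of that same logic follows; the function name and the String error type are illustrative, not the crate's code.

// Standalone sketch of the row gather performed by the CPU embedding path above.
// `ids` selects rows from a flat (vocab_size x hidden_size) table stored row-major.
// The error type is a plain String here; the crate returns Error::InvalidIndex instead.
fn embedding_rows(
    ids: &[u32],
    table: &[f32],
    hidden_size: usize,
    vocab_size: usize,
) -> Result<Vec<f32>, String> {
    let mut out = Vec::with_capacity(ids.len() * hidden_size);
    for &id in ids {
        let id = id as usize;
        if id >= vocab_size {
            return Err(format!("embedding: invalid index {id} with vocab {vocab_size}"));
        }
        out.extend_from_slice(&table[hidden_size * id..hidden_size * (id + 1)]);
    }
    Ok(out)
}

fn main() {
    // A 3x2 table: row i holds [i*10, i*10 + 1].
    let table = vec![0.0, 1.0, 10.0, 11.0, 20.0, 21.0];
    let ids = [2u32, 0, 1];
    let out = embedding_rows(&ids, &table, 2, 3).unwrap();
    assert_eq!(out, vec![20.0, 21.0, 0.0, 1.0, 10.0, 11.0]);
    println!("{out:?}");
}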
@@ -346,6 +346,15 @@ impl CudaStorage {
         }
     }
 
+    pub(crate) fn embedding_impl(
+        &self,
+        rhs: &Self,
+        hidden_size: usize,
+        vocab_size: usize,
+    ) -> Result<Self> {
+        todo!("Implement embedding for gpu");
+    }
+
     pub(crate) fn matmul_impl(
         &self,
         rhs: &Self,
@@ -2,6 +2,7 @@ use crate::{CpuStorage, Error, Result};
 
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
 pub enum DType {
+    U32,
     F32,
     F64,
 }
@@ -9,6 +10,7 @@ pub enum DType {
 impl DType {
     pub fn size_in_bytes(&self) -> usize {
         match self {
+            Self::U32 => 4,
             Self::F32 => 4,
             Self::F64 => 8,
         }
@@ -70,5 +72,6 @@ macro_rules! with_dtype {
         }
     };
 }
+with_dtype!(u32, U32);
 with_dtype!(f32, F32);
 with_dtype!(f64, F64);
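The dtype hunks simply teach the enum about the new U32 variant: a 4-byte element size and, via with_dtype!(u32, U32), presumably the same conversion helpers the f32/f64 types already get (the macro body is outside this hunk). A quick standalone sanity check of the byte sizes wired in above:

// Standalone check that the sizes used by DType::size_in_bytes above match the
// Rust primitive sizes: u32 and f32 are 4 bytes, f64 is 8.
fn main() {
    assert_eq!(std::mem::size_of::<u32>(), 4);
    assert_eq!(std::mem::size_of::<f32>(), 4);
    assert_eq!(std::mem::size_of::<f64>(), 8);
    println!("u32/f32 -> 4 bytes, f64 -> 8 bytes");
}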
@@ -76,6 +76,10 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn embedding_impl(&self, _: &Self, _: usize, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn matmul_impl(
         &self,
         _: &Self,
src/error.rs
@@ -15,6 +15,13 @@ pub enum Error {
     #[error("backward is not supported for {op}")]
     BackwardNotSupported { op: &'static str },
 
+    #[error("{op} invalid index {index} with vocab {vocab_size}")]
+    InvalidIndex {
+        op: &'static str,
+        index: usize,
+        vocab_size: usize,
+    },
+
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,
 
@@ -65,6 +72,9 @@ pub enum Error {
 
     #[error(transparent)]
     Cuda(#[from] crate::CudaError),
+
+    #[error(transparent)]
+    TryFromIntError(#[from] core::num::TryFromIntError),
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
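The new TryFromIntError conversion exists because the CPU path calls index.try_into()? to turn a u32 id into a usize. A minimal standalone illustration of that plumbing follows; it assumes the thiserror crate (which the #[error(...)] attributes above imply) and uses throwaway names rather than the crate's own Error type.

use thiserror::Error;

// Minimal illustration of why a #[from] conversion for TryFromIntError is handy:
// `?` on try_into() can then bubble up through the crate's own error type.
#[derive(Debug, Error)]
enum MyError {
    #[error("{op} invalid index {index} with vocab {vocab_size}")]
    InvalidIndex { op: &'static str, index: usize, vocab_size: usize },
    #[error(transparent)]
    TryFromIntError(#[from] core::num::TryFromIntError),
}

fn check_index(id: u32, vocab_size: usize) -> Result<usize, MyError> {
    let index: usize = id.try_into()?; // TryFromIntError converts automatically via #[from]
    if index >= vocab_size {
        return Err(MyError::InvalidIndex { op: "embedding", index, vocab_size });
    }
    Ok(index)
}

fn main() {
    assert!(check_index(3, 10).is_ok());
    assert!(check_index(42, 10).is_err());
}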
@@ -7,6 +7,7 @@ pub(crate) enum Op {
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
     Matmul(Tensor, Tensor),
+    Embedding(Tensor, Tensor),
 
     Cat(Vec<Tensor>, usize),
 
@@ -122,6 +122,31 @@ impl Storage {
         }
     }
 
+    pub(crate) fn embedding_impl(
+        &self,
+        rhs: &Self,
+        hidden_size: usize,
+        vocab_size: usize,
+    ) -> Result<Self> {
+        self.same_device(rhs, "matmul")?;
+        self.same_dtype(rhs, "matmul")?;
+        match (self, rhs) {
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
+                Ok(Self::Cuda(storage))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "embedding",
+            }),
+        }
+    }
+
     pub(crate) fn matmul_impl(
         &self,
         rhs: &Self,
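The storage layer only routes: it checks that both operands live on the same device, forwards to the backend implementation, and rewraps the result in the matching variant. A stripped-down sketch of that dispatch shape, with toy types standing in for the real backends:

// Toy illustration of the per-device dispatch used by Storage::embedding_impl above.
// CpuData and CudaData are placeholders, not the crate's backend storage types.
#[derive(Debug)]
struct CpuData(Vec<f32>);
#[derive(Debug)]
struct CudaData; // placeholder for a device buffer

#[derive(Debug)]
enum Storage {
    Cpu(CpuData),
    Cuda(CudaData),
}

#[derive(Debug)]
enum OpError {
    DeviceMismatch { op: &'static str },
}

impl Storage {
    fn binary_op(&self, rhs: &Self, op: &'static str) -> Result<Storage, OpError> {
        match (self, rhs) {
            // Same device: forward to the backend and rewrap in the same variant.
            (Storage::Cpu(_lhs), Storage::Cpu(_rhs)) => Ok(Storage::Cpu(CpuData(vec![]))),
            (Storage::Cuda(_lhs), Storage::Cuda(_rhs)) => Ok(Storage::Cuda(CudaData)),
            // Mixed devices: report the mismatch instead of silently copying.
            _ => Err(OpError::DeviceMismatch { op }),
        }
    }
}

fn main() {
    let a = Storage::Cpu(CpuData(vec![1.0]));
    let b = Storage::Cuda(CudaData);
    assert!(matches!(a.binary_op(&b, "embedding"), Err(OpError::DeviceMismatch { .. })));
}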
@@ -344,6 +344,33 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
+    pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
+        if !rhs.is_contiguous() {
+            return Err(Error::RequiresContiguous { op: "embedding" });
+        } else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 {
+            return Err(Error::ShapeMismatchBinaryOp {
+                lhs: ids.shape.clone(),
+                rhs: rhs.shape.clone(),
+                op: "embedding",
+            });
+        }
+        let seq_len = ids.shape().r1()?;
+        let (vocab_size, hidden_size) = rhs.shape().r2()?;
+        let storage = ids
+            .storage
+            .embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
+        let shape: Shape = (seq_len, hidden_size).into();
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape: shape.clone(),
+            stride: shape.stride_contiguous(),
+            op: Some(Op::Embedding(ids.clone(), rhs.clone())),
+            is_variable: false,
+        };
+        Ok(Self(Arc::new(tensor_)))
+    }
+
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
         crate::StridedIndex::new(self.dims(), self.stride())
     }
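The public entry point enforces a simple contract: ids must be rank 1, the weight table must be a contiguous rank-2 (vocab_size, hidden_size) tensor, and the result has shape (seq_len, hidden_size). A standalone check of that shape rule, using plain slices instead of the crate's Shape type:

// Standalone sketch of the shape rules enforced by Tensor::embedding above:
// ids must be rank 1, the table rank 2, and the output is (seq_len, hidden_size).
fn embedding_out_shape(
    ids_shape: &[usize],
    table_shape: &[usize],
) -> Result<(usize, usize), String> {
    if ids_shape.len() != 1 || table_shape.len() != 2 {
        return Err(format!(
            "embedding: shape mismatch, ids {ids_shape:?} vs table {table_shape:?}"
        ));
    }
    let seq_len = ids_shape[0];
    let (_vocab_size, hidden_size) = (table_shape[0], table_shape[1]);
    Ok((seq_len, hidden_size))
}

fn main() {
    assert_eq!(embedding_out_shape(&[5], &[100, 32]), Ok((5, 32)));
    assert!(embedding_out_shape(&[2, 3], &[100, 32]).is_err());
}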
@@ -740,6 +767,7 @@ impl Tensor {
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
+                | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
                     track_grad |= tg;
@@ -830,6 +858,9 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
+                Op::Embedding(_lhs, _rhs) => {
+                    todo!("Backward for embedding not implemented");
+                }
                 Op::Matmul(lhs, rhs) => {
                     // Skipping checks, the op went ok, we can skip
                     // the matmul size checks for now.
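The backward pass is left as a todo! in this commit. For context, the textbook backward rule for an embedding lookup is a scatter-add: each output-gradient row is accumulated into the table row its id selected. The sketch below illustrates that rule on plain slices; it is a general description, not the implementation the crate later adopted.

// Sketch of the usual backward rule for an embedding lookup: the gradient w.r.t.
// the table is a scatter-add of the output-gradient rows into a zeroed table.
fn embedding_backward(
    ids: &[u32],
    grad_out: &[f32], // (seq_len * hidden_size), row-major
    vocab_size: usize,
    hidden_size: usize,
) -> Vec<f32> {
    let mut grad_table = vec![0f32; vocab_size * hidden_size];
    for (pos, &id) in ids.iter().enumerate() {
        let row = id as usize * hidden_size;
        for j in 0..hidden_size {
            // Repeated ids accumulate their gradients, hence the +=.
            grad_table[row + j] += grad_out[pos * hidden_size + j];
        }
    }
    grad_table
}

fn main() {
    // Two ids pointing at the same row accumulate: row 1 receives [1+3, 2+4].
    let grads = embedding_backward(&[1, 1], &[1.0, 2.0, 3.0, 4.0], 3, 2);
    assert_eq!(grads, vec![0.0, 0.0, 4.0, 6.0, 0.0, 0.0]);
}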