mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Merge pull request #6 from LaurentMazare/add_embedding
Adding embedding op (not generic gather, no select).
This commit is contained in:
@ -8,6 +8,7 @@ use gemm::{gemm, Parallelism};
|
||||
// intercept the oom errors to avoid panicking and provide a proper error.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CpuStorage {
|
||||
U32(Vec<u32>),
|
||||
F32(Vec<f32>),
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
@ -15,6 +16,7 @@ pub enum CpuStorage {
|
||||
impl CpuStorage {
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::U32(_) => DType::U32,
|
||||
Self::F32(_) => DType::F32,
|
||||
Self::F64(_) => DType::F64,
|
||||
}
|
||||
@ -36,6 +38,13 @@ impl CpuStorage {
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
Self::U32(storage) => {
|
||||
let index = StridedIndex::new(shape.dims(), stride);
|
||||
let mul = mul as u32;
|
||||
let add = add as u32;
|
||||
let data = index.map(|i| storage[i] * mul + add).collect();
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let index = StridedIndex::new(shape.dims(), stride);
|
||||
let mul = mul as f32;
|
||||
@ -64,6 +73,9 @@ impl CpuStorage {
|
||||
let data = index.map(|i| B::f64(storage[i])).collect();
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
Self::U32(_storage) => {
|
||||
todo!("No unary for u32 because of neg, sqrt")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -148,6 +160,58 @@ impl CpuStorage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
CpuStorage::U32(lhs) => match rhs {
|
||||
CpuStorage::F32(rhs) => {
|
||||
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
|
||||
for &index in lhs {
|
||||
let index: usize = index.try_into()?;
|
||||
if index >= vocab_size {
|
||||
return Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size,
|
||||
op: "embedding",
|
||||
});
|
||||
} else {
|
||||
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
Ok(CpuStorage::F32(weights))
|
||||
}
|
||||
CpuStorage::F64(rhs) => {
|
||||
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
|
||||
for &index in lhs {
|
||||
let index: usize = index.try_into()?;
|
||||
if index >= vocab_size {
|
||||
return Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size,
|
||||
op: "embedding",
|
||||
});
|
||||
} else {
|
||||
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
Ok(CpuStorage::F64(weights))
|
||||
}
|
||||
rhs => Err(Error::UnexpectedDType {
|
||||
expected: DType::F32,
|
||||
got: rhs.dtype(),
|
||||
}),
|
||||
},
|
||||
lhs => Err(Error::UnexpectedDType {
|
||||
expected: DType::U32,
|
||||
got: lhs.dtype(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
@ -238,6 +302,10 @@ impl CpuStorage {
|
||||
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::U32 => {
|
||||
let data = vec![1u32; elem_count];
|
||||
Self::U32(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = vec![1f32; elem_count];
|
||||
Self::F32(data)
|
||||
@ -252,6 +320,10 @@ impl CpuStorage {
|
||||
pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::U32 => {
|
||||
let data = vec![0u32; elem_count];
|
||||
Self::U32(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = vec![0f32; elem_count];
|
||||
Self::F32(data)
|
||||
|
Reference in New Issue
Block a user