Remove the embedding ops in favor of index-select. (#299)

* Remove the embedding ops in favor of index-select.

* Also remove the cuda kernels.
Author: Laurent Mazare
Date: 2023-08-02 05:42:11 +01:00 (committed by GitHub)
Parent: cc76c63202
Commit: 4b3bd79fbd
11 changed files with 11 additions and 209 deletions
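In short: the free function `Tensor::embedding(&ids, &rhs)` becomes the method `rhs.embedding(&ids)`, which now just delegates to `index_select` on dimension 0. A minimal sketch of the new call pattern, mirroring the updated doc test and unit test further down:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A (vocab_size, hidden_size) embedding table.
    let table = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
    // `embedding` is now a thin wrapper around `index_select` on dim 0,
    // so both calls produce the same rows.
    let emb = table.embedding(&ids)?;
    let via_is = table.index_select(&ids, 0)?;
    assert_eq!(emb.to_vec2::<f32>()?, via_is.to_vec2::<f32>()?);
    Ok(())
}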


@@ -37,7 +37,6 @@ pub trait BackendStorage: Sized {
         _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self>;
-    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
     fn scatter_add(
         &self,


@@ -59,7 +59,6 @@ impl Tensor {
                | Op::Binary(lhs, rhs, _)
                | Op::Gather(lhs, rhs, _)
                | Op::IndexSelect(lhs, rhs, _)
-               | Op::Embedding(lhs, rhs)
                | Op::Matmul(lhs, rhs) => {
                    let (tg, nodes) = walk(lhs, nodes, already_seen);
                    track_grad |= tg;
@@ -188,9 +187,6 @@ impl Tensor {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                }
-               Op::Embedding(_lhs, _rhs) => {
-                   Err(Error::BackwardNotSupported { op: "embedding" })?
-               }
                Op::Matmul(lhs, rhs) => {
                    // Skipping checks, the op went ok, we can skip
                    // the matmul size checks for now.
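One consequence worth noting: the removed `Op::Embedding` arm above returned `BackwardNotSupported`, while `Op::IndexSelect` has a working backward pass via the `index_add` kept above, so embedding lookups become differentiable. A small sketch of the gradient that now flows, assuming the crate's `Var`/`backward` API:

use candle::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let table = Var::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
    // Forward is an index_select, so backward scatters the output
    // gradient back into the selected rows with index_add.
    let loss = table.embedding(&ids)?.sum_all()?;
    let grads = loss.backward()?;
    let g = grads.get(&table).expect("no grad for table");
    // Row 2 was picked twice, row 1 once, row 0 never.
    assert_eq!(g.to_vec2::<f32>()?, &[[0., 0.], [1., 1.], [2., 2.]]);
    Ok(())
}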


@@ -861,58 +861,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
     }
 }
 
-struct Embedding<'a, I: IntDType> {
-    vocab_size: usize,
-    hidden_size: usize,
-    ids: &'a [I],
-    ids_l: &'a Layout,
-}
-
-impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
-    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
-        if !layout.is_contiguous() {
-            Err(Error::RequiresContiguous { op: "embedding" })?
-        }
-        let vs = &vs[layout.start_offset()..];
-        let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
-        match self.ids_l.contiguous_offsets() {
-            Some((o1, o2)) => {
-                for index in self.ids[o1..o2].iter() {
-                    let index = index.as_usize();
-                    if index >= self.vocab_size {
-                        Err(Error::InvalidIndex {
-                            index,
-                            size: self.vocab_size,
-                            op: "take",
-                        }
-                        .bt())?
-                    } else {
-                        let hidden_size = self.hidden_size;
-                        values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
-                    }
-                }
-            }
-            None => {
-                for index in self.ids_l.strided_index() {
-                    let index = self.ids[index].as_usize();
-                    if index >= self.vocab_size {
-                        Err(Error::InvalidIndex {
-                            index,
-                            size: self.vocab_size,
-                            op: "take",
-                        }
-                        .bt())?
-                    } else {
-                        let hidden_size = self.hidden_size;
-                        values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
-                    }
-                }
-            }
-        }
-        Ok(values)
-    }
-}
-
 fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
     match src_l.strided_blocks() {
         crate::StridedBlocks::SingleBlock { start_offset, len } => {
@@ -1664,27 +1612,6 @@ impl BackendStorage for CpuStorage {
         Conv1D(params).map(self, l, kernel, kernel_l)
     }
 
-    fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
-        match self {
-            Self::U8(ids) => Embedding {
-                vocab_size,
-                hidden_size,
-                ids,
-                ids_l,
-            }
-            .map(rhs, rhs_l),
-            Self::U32(ids) => Embedding {
-                vocab_size,
-                hidden_size,
-                ids,
-                ids_l,
-            }
-            .map(rhs, rhs_l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
-        }
-    }
-
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         match ids {
             Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
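For reference, the deleted CPU kernel above amounts to a bounds-checked row gather, which is exactly what `IndexSelect` with `dim == 0` computes. A standalone sketch of that logic on contiguous storage (plain Rust, not candle's actual types):

fn gather_rows<T: Copy>(
    table: &[T],          // contiguous (vocab_size * hidden_size) values
    hidden_size: usize,
    vocab_size: usize,
    ids: &[u32],
) -> Result<Vec<T>, String> {
    let mut out = Vec::with_capacity(ids.len() * hidden_size);
    for &id in ids {
        let id = id as usize;
        if id >= vocab_size {
            // Mirrors the InvalidIndex error in the removed code.
            return Err(format!("index {id} >= vocab size {vocab_size}"));
        }
        // Copy one hidden_size-wide row per id.
        out.extend_from_slice(&table[id * hidden_size..(id + 1) * hidden_size]);
    }
    Ok(out)
}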


@@ -690,46 +690,6 @@ impl<U: UnaryOpT> Map1 for U {
     }
 }
 
-struct Embedding<'a>(&'a CudaStorage, &'a Layout);
-
-impl<'a> Map1 for Embedding<'a> {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        rhs: &CudaSlice<T>,
-        dev: &CudaDevice,
-        rhs_l: &Layout,
-    ) -> Result<CudaSlice<T>> {
-        let ids_l = &self.1;
-        let (name, ids) = match &self.0.slice {
-            CudaStorageSlice::U32(slice) => {
-                ("emb_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
-            CudaStorageSlice::U8(slice) => {
-                ("emb_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
-            _ => Err(CudaError::UnexpectedDType {
-                msg: "embedding ids should be u8 or u32",
-                expected: DType::U32,
-                got: self.0.dtype(),
-            })
-            .w()?,
-        };
-        let shape = ids_l.shape();
-        let (v_size, h_size) = rhs_l.shape().dims2()?;
-        let dims = shape.dims();
-        let el = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
-        let rhs = &rhs.slice(rhs_l.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
-        // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
-        let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(out)
-    }
-}
-
 struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
 
 impl<'a> Map1 for IndexSelect<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1421,12 +1381,6 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
-    fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        let device = self.device().clone();
-        let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
-        Ok(Self { slice, device })
-    }
-
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         let device = self.device().clone();
         let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;


@@ -75,9 +75,6 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }


@@ -65,7 +65,6 @@ pub enum Op {
     // The third argument is the reduced shape with `keepdim=true`.
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
-    Embedding(Tensor, Tensor),
     Gather(Tensor, Tensor, usize),
     ScatterAdd(Tensor, Tensor, Tensor, usize),
     IndexSelect(Tensor, Tensor, usize),


@@ -295,26 +295,6 @@ impl Storage {
         }
     }
 
-    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        self.same_device(rhs, "embedding")?;
-        match (self, rhs) {
-            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
-                let storage = lhs.embedding(layout, rhs, rhs_l)?;
-                Ok(Self::Cpu(storage))
-            }
-            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding(layout, rhs, rhs_l)?;
-                Ok(Self::Cuda(storage))
-            }
-            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
-                lhs: lhs.device().location(),
-                rhs: rhs.device().location(),
-                op: "embedding",
-            }
-            .bt()),
-        }
-    }
-
     pub(crate) fn gather(
         &self,
         l: &Layout,


@@ -842,45 +842,35 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }
 
-    /// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the
+    /// Returns a tensor with the values from the `self` tensor at the index corresponding to the
     /// values hold in the `ids` tensor.
     ///
     /// # Arguments
     ///
+    /// * `self` - A tensor with dimensions `v, h`.
     /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
-    /// * `rhs` - A tensor with dimensions `v, h`.
     ///
     /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
     /// vocabulary size, and `h` the hidden size.
     ///
     /// ```rust
     /// use candle::{Tensor, Device};
-    /// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
     /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
-    /// let emb = Tensor::embedding(&ids, &rhs)?;
+    /// let emb = values.embedding(&ids)?;
     /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
     /// # Ok::<(), candle::Error>(())
     /// ```
-    pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
-        if !rhs.is_contiguous() {
-            Err(Error::RequiresContiguous { op: "embedding" }.bt())?
-        } else if rhs.rank() != 2 || ids.rank() != 1 {
+    pub fn embedding(&self, ids: &Self) -> Result<Self> {
+        if self.rank() != 2 || ids.rank() != 1 {
             Err(Error::ShapeMismatchBinaryOp {
-                lhs: ids.shape().clone(),
-                rhs: rhs.shape().clone(),
+                lhs: self.shape().clone(),
+                rhs: ids.shape().clone(),
                 op: "embedding",
             }
             .bt())?
         }
-        let ids_shape = ids.shape();
-        let seq_len = ids_shape.dims1()?;
-        let (_, hidden_size) = rhs.dims2()?;
-        let storage = ids
-            .storage()
-            .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
-        let shape: Shape = (seq_len, hidden_size).into();
-        let op = BackpropOp::new2(ids, rhs, Op::Embedding);
-        Ok(from_storage(storage, shape, op, false))
+        self.index_select(ids, 0)
     }
 
     pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {


@@ -534,7 +534,7 @@ fn cat(device: &Device) -> Result<()> {
 fn embeddings(device: &Device) -> Result<()> {
     let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
     let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
-    let hs = Tensor::embedding(&ids, &t)?;
+    let hs = t.embedding(&ids)?;
     assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
     let hs = t.index_select(&ids, 0)?;
     assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);


@@ -142,7 +142,7 @@ impl EncodecEuclideanCodebook {
     }
 
     fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
-        let quantize = Tensor::embedding(embed_ind, &self.embed)?;
+        let quantize = self.embed.embedding(embed_ind)?;
         Ok(quantize)
     }
 }


@@ -3,32 +3,6 @@
 #include "cuda_utils.cuh"
 #include<stdint.h>
 
-#define EMB_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
-    const size_t numel, \
-    const size_t num_dims, \
-    const size_t *info, \
-    const INDEX_TYPENAME *ids, \
-    const TYPENAME *inp, \
-    TYPENAME *out, \
-    const size_t h_size, \
-    const size_t v_size \
-) { \
-    const size_t *dims = info; \
-    const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
-        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
-            memcpy(&out[i * h_size], &inp[ids[i] * h_size], h_size * sizeof(TYPENAME)); \
-        } \
-    } \
-    else { \
-        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
-            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
-            memcpy(&out[i * h_size], &inp[ids[strided_i] * h_size], h_size * sizeof(TYPENAME)); \
-        } \
-    } \
-} \
-
 template<typename T, typename I>
 __device__ void index_select(
     const size_t numel,
@@ -177,8 +151,6 @@ extern "C" __global__ void FN_NAME( \
 #if __CUDA_ARCH__ >= 800
-EMB_OP(__nv_bfloat16, uint32_t, emb_u32_bf16)
-EMB_OP(__nv_bfloat16, uint8_t, emb_u8_bf16)
 IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
 IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
 GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
@@ -190,8 +162,6 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
-EMB_OP(__half, uint32_t, emb_u32_f16)
-EMB_OP(__half, uint8_t, emb_u8_f16)
 IS_OP(__half, uint32_t, is_u32_f16)
 IS_OP(__half, uint8_t, is_u8_f16)
 GATHER_OP(__half, uint32_t, gather_u32_f16)
@@ -202,16 +172,6 @@ SA_OP(__half, uint32_t, sa_u32_f16)
 SA_OP(__half, uint8_t, sa_u8_f16)
 #endif
 
-EMB_OP(float, uint32_t, emb_u32_f32)
-EMB_OP(double, uint32_t, emb_u32_f64)
-EMB_OP(uint8_t, uint32_t, emb_u32_u8)
-EMB_OP(uint32_t, uint32_t, emb_u32_u32)
-EMB_OP(float, uint8_t, emb_u8_f32)
-EMB_OP(double, uint8_t, emb_u8_f64)
-EMB_OP(uint8_t, uint8_t, emb_u8_u8)
-EMB_OP(uint32_t, uint8_t, emb_u8_u32)
 IS_OP(float, uint32_t, is_u32_f32)
 IS_OP(double, uint32_t, is_u32_f64)
 IS_OP(uint8_t, uint32_t, is_u32_u8)
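With the emb_* kernels gone, GPU lookups route through the retained is_* index-select kernels instead. A quick sketch of the call, assuming a CUDA-enabled build of the crate:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let table = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &dev)?;
    let ids = Tensor::new(&[2u32, 1u32, 2u32], &dev)?;
    // Dispatches to the is_u32_f32 kernel kept above.
    let emb = table.embedding(&ids)?;
    assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
    Ok(())
}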