mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope. * Use 3d rope in the rope/ropei/rope_thd functions. * Get the CPU versions to work. * Fix the cuda version. * Adapt the metal side. * Fix the metal tests.
This commit is contained in:
@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
|
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
|
||||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (2 * idx >= bh * td) return;
|
if (2 * idx >= bh * td) return;
|
||||||
|
|
||||||
uint32_t rope_idx = idx % (td / 2);
|
uint32_t rope_idx = idx % (td / 2);
|
||||||
|
if (stride_b > 0) {
|
||||||
|
uint32_t b_idx = (2 * idx) / stride_b;
|
||||||
|
rope_idx += b_idx * (td / 2);
|
||||||
|
}
|
||||||
T c = cos[rope_idx];
|
T c = cos[rope_idx];
|
||||||
T s = sin[rope_idx];
|
T s = sin[rope_idx];
|
||||||
|
|
||||||
@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
|
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
|
||||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (2 * idx >= bh * td) return;
|
if (2 * idx >= bh * td) return;
|
||||||
|
|
||||||
@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
|
|||||||
uint32_t i1 = i_bh * td + i_t * d + i_d;
|
uint32_t i1 = i_bh * td + i_t * d + i_d;
|
||||||
uint32_t i2 = i1 + d / 2;
|
uint32_t i2 = i1 + d / 2;
|
||||||
uint32_t i_cs = i_t * (d / 2) + i_d;
|
uint32_t i_cs = i_t * (d / 2) + i_d;
|
||||||
|
if (stride_b > 0) {
|
||||||
|
uint32_t b_idx = (2 * idx) / stride_b;
|
||||||
|
i_cs += b_idx * (td / 2);
|
||||||
|
}
|
||||||
T c = cos[i_cs];
|
T c = cos[i_cs];
|
||||||
T s = sin[i_cs];
|
T s = sin[i_cs];
|
||||||
|
|
||||||
@ -259,7 +267,8 @@ __device__ void rope_thd(
|
|||||||
const uint32_t b,
|
const uint32_t b,
|
||||||
const uint32_t t,
|
const uint32_t t,
|
||||||
const uint32_t h,
|
const uint32_t h,
|
||||||
const uint32_t d
|
const uint32_t d,
|
||||||
|
const uint32_t stride_b
|
||||||
) {
|
) {
|
||||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (2 * idx >= b * t * h * d) return;
|
if (2 * idx >= b * t * h * d) return;
|
||||||
@ -270,6 +279,10 @@ __device__ void rope_thd(
|
|||||||
uint32_t i1 = i_bth * d + i_d;
|
uint32_t i1 = i_bth * d + i_d;
|
||||||
uint32_t i2 = i1 + d / 2;
|
uint32_t i2 = i1 + d / 2;
|
||||||
uint32_t i_cs = i_t * (d / 2) + i_d;
|
uint32_t i_cs = i_t * (d / 2) + i_d;
|
||||||
|
if (stride_b > 0) {
|
||||||
|
uint32_t b_idx = (2 * idx) / stride_b;
|
||||||
|
i_cs += b_idx * ((t * d) / 2);
|
||||||
|
}
|
||||||
T c = cos[i_cs];
|
T c = cos[i_cs];
|
||||||
T s = sin[i_cs];
|
T s = sin[i_cs];
|
||||||
|
|
||||||
@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
|||||||
const TYPENAME *sin, \
|
const TYPENAME *sin, \
|
||||||
TYPENAME *dst, \
|
TYPENAME *dst, \
|
||||||
const uint32_t bh, \
|
const uint32_t bh, \
|
||||||
const uint32_t td) { \
|
const uint32_t td, \
|
||||||
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
|
const uint32_t stride_b) { \
|
||||||
|
ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
|
||||||
} \
|
} \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
const TYPENAME *src, \
|
const TYPENAME *src, \
|
||||||
@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
|||||||
TYPENAME *dst, \
|
TYPENAME *dst, \
|
||||||
const uint32_t bh, \
|
const uint32_t bh, \
|
||||||
const uint32_t td, \
|
const uint32_t td, \
|
||||||
const uint32_t d) { \
|
const uint32_t d, \
|
||||||
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
|
const uint32_t stride_b) { \
|
||||||
|
rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
|
||||||
} \
|
} \
|
||||||
extern "C" __global__ void FN_NAME_THD( \
|
extern "C" __global__ void FN_NAME_THD( \
|
||||||
const TYPENAME *src, \
|
const TYPENAME *src, \
|
||||||
@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
|||||||
const uint32_t b, \
|
const uint32_t b, \
|
||||||
const uint32_t t, \
|
const uint32_t t, \
|
||||||
const uint32_t h, \
|
const uint32_t h, \
|
||||||
const uint32_t d) { \
|
const uint32_t d, \
|
||||||
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
|
const uint32_t stride_b) { \
|
||||||
|
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 800
|
#if __CUDA_ARCH__ >= 800
|
||||||
|
@ -991,6 +991,7 @@ pub fn call_rope_i(
|
|||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
bh: usize,
|
bh: usize,
|
||||||
td: usize,
|
td: usize,
|
||||||
|
stride_b: usize,
|
||||||
src: &Buffer,
|
src: &Buffer,
|
||||||
src_offset: usize,
|
src_offset: usize,
|
||||||
cos: &Buffer,
|
cos: &Buffer,
|
||||||
@ -1009,6 +1010,7 @@ pub fn call_rope_i(
|
|||||||
(
|
(
|
||||||
bh,
|
bh,
|
||||||
td,
|
td,
|
||||||
|
stride_b,
|
||||||
(src, src_offset),
|
(src, src_offset),
|
||||||
(cos, cos_offset),
|
(cos, cos_offset),
|
||||||
(sin, sin_offset),
|
(sin, sin_offset),
|
||||||
@ -1034,6 +1036,7 @@ pub fn call_rope_thd(
|
|||||||
t: usize,
|
t: usize,
|
||||||
h: usize,
|
h: usize,
|
||||||
d: usize,
|
d: usize,
|
||||||
|
stride_b: usize,
|
||||||
src: &Buffer,
|
src: &Buffer,
|
||||||
src_offset: usize,
|
src_offset: usize,
|
||||||
cos: &Buffer,
|
cos: &Buffer,
|
||||||
@ -1054,6 +1057,7 @@ pub fn call_rope_thd(
|
|||||||
t,
|
t,
|
||||||
h,
|
h,
|
||||||
d,
|
d,
|
||||||
|
stride_b,
|
||||||
(src, src_offset),
|
(src, src_offset),
|
||||||
(cos, cos_offset),
|
(cos, cos_offset),
|
||||||
(sin, sin_offset),
|
(sin, sin_offset),
|
||||||
@ -1078,6 +1082,7 @@ pub fn call_rope(
|
|||||||
bh: usize,
|
bh: usize,
|
||||||
td: usize,
|
td: usize,
|
||||||
d: usize,
|
d: usize,
|
||||||
|
stride_b: usize,
|
||||||
src: &Buffer,
|
src: &Buffer,
|
||||||
src_offset: usize,
|
src_offset: usize,
|
||||||
cos: &Buffer,
|
cos: &Buffer,
|
||||||
@ -1097,6 +1102,7 @@ pub fn call_rope(
|
|||||||
bh,
|
bh,
|
||||||
td,
|
td,
|
||||||
d,
|
d,
|
||||||
|
stride_b,
|
||||||
(src, src_offset),
|
(src, src_offset),
|
||||||
(cos, cos_offset),
|
(cos, cos_offset),
|
||||||
(sin, sin_offset),
|
(sin, sin_offset),
|
||||||
|
@ -1097,6 +1097,7 @@ template<typename T>
|
|||||||
METAL_FUNC void ropei(
|
METAL_FUNC void ropei(
|
||||||
constant size_t &bh,
|
constant size_t &bh,
|
||||||
constant size_t &td,
|
constant size_t &td,
|
||||||
|
constant size_t &stride_b,
|
||||||
device const T *src,
|
device const T *src,
|
||||||
device const T *cos,
|
device const T *cos,
|
||||||
device const T *sin,
|
device const T *sin,
|
||||||
@ -1107,6 +1108,10 @@ METAL_FUNC void ropei(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
size_t rope_idx = tid % (td / 2);
|
size_t rope_idx = tid % (td / 2);
|
||||||
|
if (stride_b > 0) {
|
||||||
|
size_t b_idx = (2 * tid) / stride_b;
|
||||||
|
rope_idx += b_idx * (td / 2);
|
||||||
|
}
|
||||||
T c = cos[rope_idx];
|
T c = cos[rope_idx];
|
||||||
T s = sin[rope_idx];
|
T s = sin[rope_idx];
|
||||||
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
|
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
|
||||||
@ -1118,6 +1123,7 @@ METAL_FUNC void rope(
|
|||||||
constant size_t &bh,
|
constant size_t &bh,
|
||||||
constant size_t &td,
|
constant size_t &td,
|
||||||
constant size_t &d,
|
constant size_t &d,
|
||||||
|
constant size_t &stride_b,
|
||||||
device const T *src,
|
device const T *src,
|
||||||
device const T *cos,
|
device const T *cos,
|
||||||
device const T *sin,
|
device const T *sin,
|
||||||
@ -1134,6 +1140,10 @@ METAL_FUNC void rope(
|
|||||||
size_t i1 = i_bh * td + i_t * d + i_d;
|
size_t i1 = i_bh * td + i_t * d + i_d;
|
||||||
size_t i2 = i1 + d / 2;
|
size_t i2 = i1 + d / 2;
|
||||||
size_t i_cs = i_t * (d / 2) + i_d;
|
size_t i_cs = i_t * (d / 2) + i_d;
|
||||||
|
if (stride_b > 0) {
|
||||||
|
size_t b_idx = (2 * idx) / stride_b;
|
||||||
|
i_cs += b_idx * (td / 2);
|
||||||
|
}
|
||||||
T c = cos[i_cs];
|
T c = cos[i_cs];
|
||||||
T s = sin[i_cs];
|
T s = sin[i_cs];
|
||||||
dst[i1] = src[i1] * c - src[i2] * s;
|
dst[i1] = src[i1] * c - src[i2] * s;
|
||||||
@ -1146,6 +1156,7 @@ METAL_FUNC void rope_thd(
|
|||||||
constant size_t &t,
|
constant size_t &t,
|
||||||
constant size_t &h,
|
constant size_t &h,
|
||||||
constant size_t &d,
|
constant size_t &d,
|
||||||
|
constant size_t &stride_b,
|
||||||
device const T *src,
|
device const T *src,
|
||||||
device const T *cos,
|
device const T *cos,
|
||||||
device const T *sin,
|
device const T *sin,
|
||||||
@ -1160,8 +1171,12 @@ METAL_FUNC void rope_thd(
|
|||||||
const size_t i_t = (i_bth / h) % t;
|
const size_t i_t = (i_bth / h) % t;
|
||||||
const size_t i1 = i_bth * d + i_d;
|
const size_t i1 = i_bth * d + i_d;
|
||||||
const size_t i2 = i1 + d / 2;
|
const size_t i2 = i1 + d / 2;
|
||||||
const size_t i_cs = i_t * (d / 2) + i_d;
|
size_t i_cs = i_t * (d / 2) + i_d;
|
||||||
T c = cos[i_cs];
|
if (stride_b > 0) {
|
||||||
|
const size_t b_idx = (2 * idx) / stride_b;
|
||||||
|
i_cs += b_idx * ((t * d) / 2);
|
||||||
|
}
|
||||||
|
T c = cos[i_cs];
|
||||||
T s = sin[i_cs];
|
T s = sin[i_cs];
|
||||||
dst[i1] = src[i1] * c - src[i2] * s;
|
dst[i1] = src[i1] * c - src[i2] * s;
|
||||||
dst[i2] = src[i1] * s + src[i2] * c;
|
dst[i2] = src[i1] * s + src[i2] * c;
|
||||||
@ -1171,38 +1186,41 @@ METAL_FUNC void rope_thd(
|
|||||||
kernel void FN_NAME_I( \
|
kernel void FN_NAME_I( \
|
||||||
constant size_t &bh, \
|
constant size_t &bh, \
|
||||||
constant size_t &td, \
|
constant size_t &td, \
|
||||||
|
constant size_t &stride_b, \
|
||||||
device const TYPENAME *src, \
|
device const TYPENAME *src, \
|
||||||
device const TYPENAME *cos, \
|
device const TYPENAME *cos, \
|
||||||
device const TYPENAME *sin, \
|
device const TYPENAME *sin, \
|
||||||
device TYPENAME *dst, \
|
device TYPENAME *dst, \
|
||||||
uint tid [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
|
ropei<TYPENAME>(bh, td, stride_b, src, cos, sin, dst, tid); \
|
||||||
}\
|
}\
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
constant size_t &bh, \
|
constant size_t &bh, \
|
||||||
constant size_t &td, \
|
constant size_t &td, \
|
||||||
constant size_t &d, \
|
constant size_t &d, \
|
||||||
|
constant size_t &stride_b, \
|
||||||
device const TYPENAME *src, \
|
device const TYPENAME *src, \
|
||||||
device const TYPENAME *cos, \
|
device const TYPENAME *cos, \
|
||||||
device const TYPENAME *sin, \
|
device const TYPENAME *sin, \
|
||||||
device TYPENAME *dst, \
|
device TYPENAME *dst, \
|
||||||
uint idx [[ thread_position_in_grid ]] \
|
uint idx [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
|
rope<TYPENAME>(bh, td, d, stride_b, src, cos, sin, dst, idx); \
|
||||||
}\
|
}\
|
||||||
kernel void FN_NAME_THD( \
|
kernel void FN_NAME_THD( \
|
||||||
constant size_t &b, \
|
constant size_t &b, \
|
||||||
constant size_t &t, \
|
constant size_t &t, \
|
||||||
constant size_t &h, \
|
constant size_t &h, \
|
||||||
constant size_t &d, \
|
constant size_t &d, \
|
||||||
|
constant size_t &stride_b, \
|
||||||
device const TYPENAME *src, \
|
device const TYPENAME *src, \
|
||||||
device const TYPENAME *cos, \
|
device const TYPENAME *cos, \
|
||||||
device const TYPENAME *sin, \
|
device const TYPENAME *sin, \
|
||||||
device TYPENAME *dst, \
|
device TYPENAME *dst, \
|
||||||
uint idx [[ thread_position_in_grid ]] \
|
uint idx [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
|
rope_thd<TYPENAME>(b, t, h, d, stride_b, src, cos, sin, dst, idx); \
|
||||||
}\
|
}\
|
||||||
|
|
||||||
RMSNORM(rmsnorm_f32, float)
|
RMSNORM(rmsnorm_f32, float)
|
||||||
|
@ -1584,7 +1584,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
|||||||
dim,
|
dim,
|
||||||
BufferOffset::zero_offset(&input_buffer),
|
BufferOffset::zero_offset(&input_buffer),
|
||||||
BufferOffset::zero_offset(&ids_buffer),
|
BufferOffset::zero_offset(&ids_buffer),
|
||||||
&output,
|
BufferOffset::zero_offset(&output),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
|
@ -46,15 +46,23 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
Some((o1, o2)) => &sin[o1..o2],
|
Some((o1, o2)) => &sin[o1..o2],
|
||||||
};
|
};
|
||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
|
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
|
||||||
let el_count = b * h * t * d;
|
let el_count = b * h * t * d;
|
||||||
let mut dst = vec![T::zero(); el_count];
|
let mut dst = vec![T::zero(); el_count];
|
||||||
src.par_chunks(t * d)
|
src.par_chunks(t * d)
|
||||||
.zip(dst.par_chunks_mut(t * d))
|
.zip(dst.par_chunks_mut(t * d))
|
||||||
.for_each(|(src, dst)| {
|
.enumerate()
|
||||||
|
.for_each(|(bh_i, (src, dst))| {
|
||||||
for i_over_2 in 0..t * d / 2 {
|
for i_over_2 in 0..t * d / 2 {
|
||||||
let i = 2 * i_over_2;
|
let i = 2 * i_over_2;
|
||||||
dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
|
let rope_i = if unbatched_rope {
|
||||||
dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
|
let b_i = bh_i / h;
|
||||||
|
i_over_2 + b_i * t * d / 2
|
||||||
|
} else {
|
||||||
|
i_over_2
|
||||||
|
};
|
||||||
|
dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
|
||||||
|
dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||||
@ -115,6 +123,11 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
Some((o1, o2)) => sin.slice(o1..o2),
|
Some((o1, o2)) => sin.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
|
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||||
|
(h * t * d) as u32
|
||||||
|
} else {
|
||||||
|
0u32
|
||||||
|
};
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
||||||
@ -125,7 +138,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
builder.arg(&cos);
|
builder.arg(&cos);
|
||||||
builder.arg(&sin);
|
builder.arg(&sin);
|
||||||
builder.arg(&dst);
|
builder.arg(&dst);
|
||||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
|
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
@ -182,6 +195,11 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
|
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
|
||||||
};
|
};
|
||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
|
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||||
|
h * t * d
|
||||||
|
} else {
|
||||||
|
0usize
|
||||||
|
};
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
|
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
|
||||||
candle_metal_kernels::call_rope_i(
|
candle_metal_kernels::call_rope_i(
|
||||||
@ -191,6 +209,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
name,
|
name,
|
||||||
b * h,
|
b * h,
|
||||||
t * d,
|
t * d,
|
||||||
|
stride_b,
|
||||||
src.buffer(),
|
src.buffer(),
|
||||||
l_src.start_offset() * src.dtype().size_in_bytes(),
|
l_src.start_offset() * src.dtype().size_in_bytes(),
|
||||||
cos.buffer(),
|
cos.buffer(),
|
||||||
@ -205,10 +224,23 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
|
||||||
|
match *cs.dims() {
|
||||||
|
[t, d] => Ok((t, d)),
|
||||||
|
[b, t, d] => {
|
||||||
|
if b != b_sz {
|
||||||
|
candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
|
||||||
|
}
|
||||||
|
Ok((t, d))
|
||||||
|
}
|
||||||
|
_ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||||
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
||||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
|
||||||
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
|
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
|
||||||
if cos_n_embd * 2 != n_embd
|
if cos_n_embd * 2 != n_embd
|
||||||
|| sin_n_embd * 2 != n_embd
|
|| sin_n_embd * 2 != n_embd
|
||||||
|| seq_len > cos_seq_len
|
|| seq_len > cos_seq_len
|
||||||
@ -292,16 +324,24 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
Some((o1, o2)) => &sin[o1..o2],
|
Some((o1, o2)) => &sin[o1..o2],
|
||||||
};
|
};
|
||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
|
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
|
||||||
let el_count = b * h * t * d;
|
let el_count = b * h * t * d;
|
||||||
let mut dst = vec![T::zero(); el_count];
|
let mut dst = vec![T::zero(); el_count];
|
||||||
src.par_chunks(t * d)
|
src.par_chunks(t * d)
|
||||||
.zip(dst.par_chunks_mut(t * d))
|
.zip(dst.par_chunks_mut(t * d))
|
||||||
.for_each(|(src, dst)| {
|
.enumerate()
|
||||||
|
.for_each(|(bh_i, (src, dst))| {
|
||||||
for i_t in 0..t {
|
for i_t in 0..t {
|
||||||
for i_d in 0..d / 2 {
|
for i_d in 0..d / 2 {
|
||||||
let i1 = i_t * d + i_d;
|
let i1 = i_t * d + i_d;
|
||||||
let i2 = i1 + d / 2;
|
let i2 = i1 + d / 2;
|
||||||
let i_cs = i_t * (d / 2) + i_d;
|
let i_cs = i_t * (d / 2) + i_d;
|
||||||
|
let i_cs = if unbatched_rope {
|
||||||
|
let b_i = bh_i / h;
|
||||||
|
i_cs + b_i * t * d / 2
|
||||||
|
} else {
|
||||||
|
i_cs
|
||||||
|
};
|
||||||
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
|
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
|
||||||
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
|
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
|
||||||
}
|
}
|
||||||
@ -365,6 +405,11 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
Some((o1, o2)) => sin.slice(o1..o2),
|
Some((o1, o2)) => sin.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
|
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||||
|
(h * t * d) as u32
|
||||||
|
} else {
|
||||||
|
0u32
|
||||||
|
};
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
||||||
@ -375,7 +420,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
builder.arg(&cos);
|
builder.arg(&cos);
|
||||||
builder.arg(&sin);
|
builder.arg(&sin);
|
||||||
builder.arg(&dst);
|
builder.arg(&dst);
|
||||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
|
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
@ -432,6 +477,11 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
|
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
|
||||||
};
|
};
|
||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
|
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||||
|
h * t * d
|
||||||
|
} else {
|
||||||
|
0usize
|
||||||
|
};
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
|
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
|
||||||
candle_metal_kernels::call_rope(
|
candle_metal_kernels::call_rope(
|
||||||
@ -442,6 +492,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
b * h,
|
b * h,
|
||||||
t * d,
|
t * d,
|
||||||
d,
|
d,
|
||||||
|
stride_b,
|
||||||
src.buffer(),
|
src.buffer(),
|
||||||
l_src.start_offset() * src.dtype().size_in_bytes(),
|
l_src.start_offset() * src.dtype().size_in_bytes(),
|
||||||
cos.buffer(),
|
cos.buffer(),
|
||||||
@ -457,9 +508,9 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||||
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
||||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
|
||||||
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
|
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
|
||||||
if cos_n_embd * 2 != n_embd
|
if cos_n_embd * 2 != n_embd
|
||||||
|| sin_n_embd * 2 != n_embd
|
|| sin_n_embd * 2 != n_embd
|
||||||
|| seq_len > cos_seq_len
|
|| seq_len > cos_seq_len
|
||||||
@ -541,14 +592,21 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
Some((o1, o2)) => &sin[o1..o2],
|
Some((o1, o2)) => &sin[o1..o2],
|
||||||
};
|
};
|
||||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||||
|
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
|
||||||
let el_count = b * h * t * d;
|
let el_count = b * h * t * d;
|
||||||
let mut dst = vec![T::zero(); el_count];
|
let mut dst = vec![T::zero(); el_count];
|
||||||
src.par_chunks(t * h * d)
|
src.par_chunks(t * h * d)
|
||||||
.zip(dst.par_chunks_mut(t * h * d))
|
.zip(dst.par_chunks_mut(t * h * d))
|
||||||
.for_each(|(src, dst)| {
|
.enumerate()
|
||||||
|
.for_each(|(b_i, (src, dst))| {
|
||||||
for i_t in 0..t {
|
for i_t in 0..t {
|
||||||
for i_d in 0..d / 2 {
|
for i_d in 0..d / 2 {
|
||||||
let i_cs = i_t * (d / 2) + i_d;
|
let i_cs = i_t * (d / 2) + i_d;
|
||||||
|
let i_cs = if unbatched_rope {
|
||||||
|
i_cs + b_i * t * d / 2
|
||||||
|
} else {
|
||||||
|
i_cs
|
||||||
|
};
|
||||||
for i_h in 0..h {
|
for i_h in 0..h {
|
||||||
let i1 = i_t * h * d + i_h * d + i_d;
|
let i1 = i_t * h * d + i_h * d + i_d;
|
||||||
let i2 = i1 + d / 2;
|
let i2 = i1 + d / 2;
|
||||||
@ -616,6 +674,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
Some((o1, o2)) => sin.slice(o1..o2),
|
Some((o1, o2)) => sin.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||||
|
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||||
|
(h * t * d) as u32
|
||||||
|
} else {
|
||||||
|
0u32
|
||||||
|
};
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
||||||
@ -626,7 +689,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
builder.arg(&cos);
|
builder.arg(&cos);
|
||||||
builder.arg(&sin);
|
builder.arg(&sin);
|
||||||
builder.arg(&dst);
|
builder.arg(&dst);
|
||||||
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
|
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
@ -683,6 +746,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
|
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
|
||||||
};
|
};
|
||||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||||
|
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||||
|
h * t * d
|
||||||
|
} else {
|
||||||
|
0usize
|
||||||
|
};
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
|
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
|
||||||
candle_metal_kernels::call_rope_thd(
|
candle_metal_kernels::call_rope_thd(
|
||||||
@ -694,6 +762,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
t,
|
t,
|
||||||
h,
|
h,
|
||||||
d,
|
d,
|
||||||
|
stride_b,
|
||||||
src.buffer(),
|
src.buffer(),
|
||||||
l_src.start_offset() * src.dtype().size_in_bytes(),
|
l_src.start_offset() * src.dtype().size_in_bytes(),
|
||||||
cos.buffer(),
|
cos.buffer(),
|
||||||
@ -709,9 +778,9 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||||
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
|
let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
|
||||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
|
||||||
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
|
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
|
||||||
if cos_n_embd * 2 != n_embd
|
if cos_n_embd * 2 != n_embd
|
||||||
|| sin_n_embd * 2 != n_embd
|
|| sin_n_embd * 2 != n_embd
|
||||||
|| seq_len > cos_seq_len
|
|| seq_len > cos_seq_len
|
||||||
|
@ -4,7 +4,7 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
|
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
|
||||||
|
|
||||||
fn softmax(device: &Device) -> Result<()> {
|
fn softmax(device: &Device) -> Result<()> {
|
||||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||||
@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
assert!(sum_diff < 1e-4);
|
assert!(sum_diff < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with a 3d cos/sin
|
||||||
|
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||||
|
.map(|_| rng.random::<f32>())
|
||||||
|
.collect();
|
||||||
|
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||||
|
.map(|_| rng.random::<f32>())
|
||||||
|
.collect();
|
||||||
|
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||||
|
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||||
|
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
|
||||||
|
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;
|
||||||
|
|
||||||
|
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||||
|
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||||
|
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
|
||||||
|
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||||
|
let sum_diff = (both_rope - both_rope2)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(sum_diff, 0.);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
assert!(sum_diff < 1e-4);
|
assert!(sum_diff < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with a 3d cos/sin
|
||||||
|
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||||
|
.map(|_| rng.random::<f32>())
|
||||||
|
.collect();
|
||||||
|
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||||
|
.map(|_| rng.random::<f32>())
|
||||||
|
.collect();
|
||||||
|
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||||
|
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||||
|
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
|
||||||
|
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;
|
||||||
|
|
||||||
|
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||||
|
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||||
|
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
|
||||||
|
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||||
|
let sum_diff = (both_rope - both_rope2)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(sum_diff, 0.);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
assert!(sum_diff < 1e-4);
|
assert!(sum_diff < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with a 3d cos/sin
|
||||||
|
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||||
|
.map(|_| rng.random::<f32>())
|
||||||
|
.collect();
|
||||||
|
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||||
|
.map(|_| rng.random::<f32>())
|
||||||
|
.collect();
|
||||||
|
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||||
|
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||||
|
let rope1 = {
|
||||||
|
let src = src.transpose(1, 2)?.contiguous()?;
|
||||||
|
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
|
||||||
|
};
|
||||||
|
let rope2 = {
|
||||||
|
let src = src.transpose(1, 2)?.contiguous()?;
|
||||||
|
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||||
|
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||||
|
let both_rope = {
|
||||||
|
let src = src.transpose(1, 2)?.contiguous()?;
|
||||||
|
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
|
||||||
|
};
|
||||||
|
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||||
|
let sum_diff = (both_rope - both_rope2)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(sum_diff, 0.);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user