mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Use the new rope kernel in mistral. (#1937)
* Use the new rope kernel in mistral. * Compute the cos and sin with full precision. * Bugfix.
This commit is contained in:
@ -150,7 +150,7 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
|
||||
template <typename T>
|
||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx > bh * td) return;
|
||||
if (2 * idx >= bh * td) return;
|
||||
|
||||
uint32_t rope_idx = idx % (td / 2);
|
||||
T c = cos[rope_idx];
|
||||
@ -163,7 +163,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
|
||||
template <typename T>
|
||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx > bh * td) return;
|
||||
if (2 * idx >= bh * td) return;
|
||||
|
||||
uint32_t i_bh = idx / (td / 2);
|
||||
uint32_t i_td = idx - (td / 2) * i_bh;
|
||||
|
Reference in New Issue
Block a user