mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope.
* Use 3d rope in the rope/ropei/rope_thd functions.
* Get the CPU versions to work.
* Fix the cuda version.
* Adapt the metal side.
* Fix the metal tests.
This commit is contained in:
@ -991,6 +991,7 @@ pub fn call_rope_i(
|
||||
kernel_name: &'static str,
|
||||
bh: usize,
|
||||
td: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -1009,6 +1010,7 @@ pub fn call_rope_i(
|
||||
(
|
||||
bh,
|
||||
td,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
@ -1034,6 +1036,7 @@ pub fn call_rope_thd(
|
||||
t: usize,
|
||||
h: usize,
|
||||
d: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -1054,6 +1057,7 @@ pub fn call_rope_thd(
|
||||
t,
|
||||
h,
|
||||
d,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
@ -1078,6 +1082,7 @@ pub fn call_rope(
|
||||
bh: usize,
|
||||
td: usize,
|
||||
d: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -1097,6 +1102,7 @@ pub fn call_rope(
|
||||
bh,
|
||||
td,
|
||||
d,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
|
Reference in New Issue
Block a user