mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
Compare commits
7 Commits
bf16_metal
...
vocos
Author | SHA1 | Date | |
---|---|---|---|
3f3730b657 | |||
058a910d0e | |||
26fe162ab5 | |||
121a71e01f | |||
2d5f2a728d | |||
68f7655895 | |||
b60064780d |
@ -75,6 +75,9 @@ We also provide a some command line based examples using state of the art models
|
|||||||
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
||||||
much faster inference.
|
much faster inference.
|
||||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
||||||
|
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
|
||||||
|
- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
|
||||||
|
performance.
|
||||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||||
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
||||||
(English/Chinese) general LLMs with 6b and 34b parameters.
|
(English/Chinese) general LLMs with 6b and 34b parameters.
|
||||||
@ -193,6 +196,8 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- Replit-code-v1.5-3B.
|
- Replit-code-v1.5-3B.
|
||||||
- Bert.
|
- Bert.
|
||||||
- Yi-6B and Yi-34B.
|
- Yi-6B and Yi-34B.
|
||||||
|
- Qwen1.5.
|
||||||
|
- RWKV.
|
||||||
- Quantized LLMs.
|
- Quantized LLMs.
|
||||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
||||||
- Mistral 7b, and 7b instruct.
|
- Mistral 7b, and 7b instruct.
|
||||||
@ -210,7 +215,8 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- BLIP.
|
- BLIP.
|
||||||
- TrOCR.
|
- TrOCR.
|
||||||
- Computer Vision Models.
|
- Computer Vision Models.
|
||||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
|
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
|
||||||
|
ConvNeXTv2.
|
||||||
- yolo-v3, yolo-v8.
|
- yolo-v3, yolo-v8.
|
||||||
- Segment-Anything Model (SAM).
|
- Segment-Anything Model (SAM).
|
||||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||||
|
@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
|||||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||||
|
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||||
|
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = -v
|
||||||
|
}
|
||||||
|
vs_exp_inplace(ys);
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = v / (1.0 + *y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = -v
|
||||||
|
}
|
||||||
|
vd_exp_inplace(ys);
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = v / (1.0 + *y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
macro_rules! binary_op {
|
macro_rules! binary_op {
|
||||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -589,6 +589,13 @@ impl Tensor {
|
|||||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||||
}
|
}
|
||||||
|
Op::Unary(arg, UnaryOp::Silu) => {
    let sum_grad = grads.or_insert(arg)?;
    // d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    //
    // sigmoid(x) is evaluated directly as 1 / (1 + exp(-x)). Recovering it from the
    // forward output as `node / arg` evaluates 0 / 0 = NaN wherever the input is
    // exactly zero, which would poison the gradient at those positions.
    let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
    let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
    // Chain rule: accumulate upstream gradient times the local derivative.
    *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
|
||||||
Op::Elu(arg, alpha) => {
|
Op::Elu(arg, alpha) => {
|
||||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
@ -679,6 +679,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||||
|
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||||
@ -696,6 +697,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||||
|
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||||
@ -730,6 +732,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||||
|
("usilu", DType::F32) => strided::silu::FLOAT,
|
||||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||||
@ -745,6 +748,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||||
("uerf", DType::F16) => strided::erf::HALF,
|
("uerf", DType::F16) => strided::erf::HALF,
|
||||||
|
("usilu", DType::F16) => strided::silu::HALF,
|
||||||
("uabs", DType::F16) => strided::abs::HALF,
|
("uabs", DType::F16) => strided::abs::HALF,
|
||||||
("uceil", DType::F16) => strided::ceil::HALF,
|
("uceil", DType::F16) => strided::ceil::HALF,
|
||||||
("ufloor", DType::F16) => strided::floor::HALF,
|
("ufloor", DType::F16) => strided::floor::HALF,
|
||||||
|
@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
|||||||
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||||
|
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||||
|
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = -v
|
||||||
|
}
|
||||||
|
vs_exp_inplace(ys);
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = v / (1.0 + *y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = -v
|
||||||
|
}
|
||||||
|
vd_exp_inplace(ys);
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = v / (1.0 + *y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
macro_rules! binary_op {
|
macro_rules! binary_op {
|
||||||
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -61,6 +61,7 @@ pub enum UnaryOp {
|
|||||||
GeluErf,
|
GeluErf,
|
||||||
Erf,
|
Erf,
|
||||||
Relu,
|
Relu,
|
||||||
|
Silu,
|
||||||
Tanh,
|
Tanh,
|
||||||
Floor,
|
Floor,
|
||||||
Ceil,
|
Ceil,
|
||||||
@ -390,6 +391,7 @@ pub(crate) struct Gelu;
|
|||||||
pub(crate) struct GeluErf;
|
pub(crate) struct GeluErf;
|
||||||
pub(crate) struct Erf;
|
pub(crate) struct Erf;
|
||||||
pub(crate) struct Relu;
|
pub(crate) struct Relu;
|
||||||
|
pub(crate) struct Silu;
|
||||||
pub(crate) struct Tanh;
|
pub(crate) struct Tanh;
|
||||||
pub(crate) struct Floor;
|
pub(crate) struct Floor;
|
||||||
pub(crate) struct Ceil;
|
pub(crate) struct Ceil;
|
||||||
@ -724,6 +726,77 @@ impl UnaryOpT for Erf {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Silu operation
|
||||||
|
impl UnaryOpT for Silu {
|
||||||
|
const NAME: &'static str = "silu";
|
||||||
|
const V: Self = Silu;
|
||||||
|
#[inline(always)]
|
||||||
|
fn bf16(v: bf16) -> bf16 {
|
||||||
|
v / (bf16::ONE + (-v).exp())
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f16(v: f16) -> f16 {
|
||||||
|
v / (f16::ONE + (-v).exp())
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f32(v: f32) -> f32 {
|
||||||
|
v / (1.0 + (-v).exp())
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f64(v: f64) -> f64 {
|
||||||
|
v / (1.0 + (-v).exp())
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn u8(_: u8) -> u8 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn u32(_: u32) -> u32 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn i64(_: i64) -> i64 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
const KERNEL: &'static str = "usilu";
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
const F32_VEC: bool = true;
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
#[inline(always)]
|
||||||
|
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||||
|
crate::mkl::vs_silu(xs, ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
const F64_VEC: bool = true;
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
#[inline(always)]
|
||||||
|
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||||
|
crate::mkl::vd_silu(xs, ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
const F32_VEC: bool = true;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
#[inline(always)]
|
||||||
|
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||||
|
crate::accelerate::vs_silu(xs, ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
const F64_VEC: bool = true;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
#[inline(always)]
|
||||||
|
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||||
|
crate::accelerate::vd_silu(xs, ys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UnaryOpT for Abs {
|
impl UnaryOpT for Abs {
|
||||||
const NAME: &'static str = "abs";
|
const NAME: &'static str = "abs";
|
||||||
const KERNEL: &'static str = "uabs";
|
const KERNEL: &'static str = "uabs";
|
||||||
|
@ -508,6 +508,7 @@ impl Tensor {
|
|||||||
unary_op!(gelu_erf, GeluErf);
|
unary_op!(gelu_erf, GeluErf);
|
||||||
unary_op!(erf, Erf);
|
unary_op!(erf, Erf);
|
||||||
unary_op!(relu, Relu);
|
unary_op!(relu, Relu);
|
||||||
|
unary_op!(silu, Silu);
|
||||||
unary_op!(ceil, Ceil);
|
unary_op!(ceil, Ceil);
|
||||||
unary_op!(floor, Floor);
|
unary_op!(floor, Floor);
|
||||||
unary_op!(round, Round);
|
unary_op!(round, Round);
|
||||||
|
@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// testing compared to pytorch nn.SiLU()
|
||||||
|
let y = x.silu()?;
|
||||||
|
let grads = y.backward()?;
|
||||||
|
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&y, 4)?,
|
||||||
|
[2.8577, 0.7311, 3.9281, 0.0806]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(grad_x, 4)?,
|
||||||
|
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||||
|
);
|
||||||
|
|
||||||
// manually checked: see comments
|
// manually checked: see comments
|
||||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||||
|
@ -120,6 +120,13 @@ fn unary_op(device: &Device) -> Result<()> {
|
|||||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
|
||||||
|
[
|
||||||
|
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
|
||||||
|
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
|
||||||
|
]
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# candle-convnext
|
# candle-convnext
|
||||||
|
|
||||||
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
|
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
|
||||||
|
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
|
||||||
|
|
||||||
This candle implementation uses a pre-trained ConvNeXt network for inference. The
|
This candle implementation uses a pre-trained ConvNeXt network for inference. The
|
||||||
classification head has been trained on the ImageNet dataset and returns the
|
classification head has been trained on the ImageNet dataset and returns the
|
||||||
|
@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
|
|||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
enum Which {
|
enum Which {
|
||||||
|
Atto,
|
||||||
|
Femto,
|
||||||
|
Pico,
|
||||||
|
Nano,
|
||||||
Tiny,
|
Tiny,
|
||||||
Small,
|
Small,
|
||||||
Base,
|
Base,
|
||||||
Large,
|
Large,
|
||||||
|
AttoV2,
|
||||||
|
FemtoV2,
|
||||||
|
PicoV2,
|
||||||
|
NanoV2,
|
||||||
|
TinyV2,
|
||||||
|
BaseV2,
|
||||||
|
LargeV2,
|
||||||
XLarge,
|
XLarge,
|
||||||
|
Huge,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
fn model_filename(&self) -> String {
|
fn model_filename(&self) -> String {
|
||||||
let name = match self {
|
let name = match self {
|
||||||
Self::Tiny => "tiny",
|
Self::Atto => "convnext_atto.d2_in1k",
|
||||||
Self::Small => "small",
|
Self::Femto => "convnext_femto.d1_in1k",
|
||||||
Self::Base => "base",
|
Self::Pico => "convnext_pico.d1_in1k",
|
||||||
Self::Large => "large",
|
Self::Nano => "convnext_nano.d1h_in1k",
|
||||||
Self::XLarge => "xlarge",
|
Self::Tiny => "convnext_tiny.fb_in1k",
|
||||||
};
|
Self::Small => "convnext_small.fb_in1k",
|
||||||
// The XLarge model only has an ImageNet-22K variant
|
Self::Base => "convnext_base.fb_in1k",
|
||||||
let variant = match self {
|
Self::Large => "convnext_large.fb_in1k",
|
||||||
Self::XLarge => "fb_in22k_ft_in1k",
|
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
|
||||||
_ => "fb_in1k",
|
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
|
||||||
|
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
|
||||||
|
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
|
||||||
|
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
|
||||||
|
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
|
||||||
|
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
|
||||||
|
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
|
||||||
|
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
|
||||||
};
|
};
|
||||||
|
|
||||||
format!("timm/convnext_{name}.{variant}")
|
format!("timm/{name}")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn config(&self) -> convnext::Config {
|
fn config(&self) -> convnext::Config {
|
||||||
match self {
|
match self {
|
||||||
Self::Tiny => convnext::Config::tiny(),
|
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
|
||||||
|
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
|
||||||
|
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
|
||||||
|
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
|
||||||
|
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
|
||||||
Self::Small => convnext::Config::small(),
|
Self::Small => convnext::Config::small(),
|
||||||
Self::Base => convnext::Config::base(),
|
Self::Base | Self::BaseV2 => convnext::Config::base(),
|
||||||
Self::Large => convnext::Config::large(),
|
Self::Large | Self::LargeV2 => convnext::Config::large(),
|
||||||
Self::XLarge => convnext::Config::xlarge(),
|
Self::XLarge => convnext::Config::xlarge(),
|
||||||
|
Self::Huge => convnext::Config::huge(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
17
candle-examples/examples/rwkv/README.md
Normal file
17
candle-examples/examples/rwkv/README.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
## candle-rwkv
|
||||||
|
|
||||||
|
The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
|
||||||
|
with performance on par with transformer architectures. Several variants are
|
||||||
|
available; candle implements the v5 version and can be used with Eagle 7B ([blog
|
||||||
|
post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
|
||||||
|
avx: true, neon: false, simd128: false, f16c: true
|
||||||
|
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
|
||||||
|
The smallest prime is ϕ(2) = 2.
|
||||||
|
The smallest composite is ϕ(3) = 3.
|
||||||
|
The smallest perfect number is ϕ(5) = 5.
|
||||||
|
The smallest perfect square is ϕ(4) = 4.
|
||||||
|
The smallest perfect cube is ϕ(6) = 6.
|
||||||
|
```
|
265
candle-examples/examples/rwkv/main.rs
Normal file
265
candle-examples/examples/rwkv/main.rs
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
config: Config,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
config: Config,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
config,
|
||||||
|
tokenizer,
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
let mut tokens = self.tokenizer.encode(prompt)?;
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let mut state = State::new(1, &self.config, &self.device)?;
|
||||||
|
let mut next_logits = None;
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
let input = Tensor::new(&[[t]], &self.device)?;
|
||||||
|
let logits = self.model.forward(&input, &mut state)?;
|
||||||
|
next_logits = Some(logits);
|
||||||
|
print!("{}", self.tokenizer.decode(&[t])?)
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for _ in 0..sample_len {
|
||||||
|
let logits = match next_logits.as_ref() {
|
||||||
|
Some(logits) => logits,
|
||||||
|
None => anyhow::bail!("cannot work on an empty prompt"),
|
||||||
|
};
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
print!("{}", self.tokenizer.decode(&[next_token])?);
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let input = Tensor::new(&[[next_token]], &self.device)?;
|
||||||
|
next_logits = Some(self.model.forward(&input, &mut state)?)
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
||||||
|
enum Which {
|
||||||
|
Eagle7b,
|
||||||
|
World1b5,
|
||||||
|
World3b,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Which {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{:?}", self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
fn model_id(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
|
||||||
|
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
|
||||||
|
Self::World3b => "RWKV/rwkv-5-world-3b",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn revision(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Eagle7b => "refs/pr/1",
|
||||||
|
Self::World1b5 | Self::World3b => "refs/pr/2",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "world1b5")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
revision: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_files: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config_file: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
args.model_id
|
||||||
|
.unwrap_or_else(|| args.which.model_id().to_string()),
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision
|
||||||
|
.unwrap_or_else(|| args.which.revision().to_string()),
|
||||||
|
));
|
||||||
|
let tokenizer = match args.tokenizer {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => api
|
||||||
|
.model("lmz/candle-rwkv".to_string())
|
||||||
|
.get("rwkv_vocab_v20230424.json")?,
|
||||||
|
};
|
||||||
|
let config_filename = match args.config_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("config.json")?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weight_files {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => {
|
||||||
|
vec![repo.get("model.safetensors")?]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::new(tokenizer)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
config,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -55,6 +55,11 @@ __device__ __forceinline__ T relu_fwd(T x) {
|
|||||||
return maxg(x, zero);
|
return maxg(x, zero);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SiLU forward: x / (1 + exp(-x)), with the exponential taken through the generic `expg`.
template<typename T>
__device__ __forceinline__ T silu_fwd(T x) {
  const T one = static_cast<T>(1);
  const T denom = one + expg(-x);
  return x / denom;
}
|
||||||
|
|
||||||
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
const size_t numel, \
|
const size_t numel, \
|
||||||
@ -103,6 +108,7 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
|
|||||||
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
|
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
|
||||||
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
|
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
|
||||||
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
||||||
|
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
|
||||||
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -127,6 +133,7 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
|||||||
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
|
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
|
||||||
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||||
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
||||||
|
UNARY_OP(__half, usilu_f16, silu_fwd(x))
|
||||||
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -173,5 +180,7 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
|
|||||||
UNARY_OP(double, urelu_f64, relu_fwd(x))
|
UNARY_OP(double, urelu_f64, relu_fwd(x))
|
||||||
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
|
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
|
||||||
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
|
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
|
||||||
|
UNARY_OP(float, usilu_f32, silu_fwd(x))
|
||||||
|
UNARY_OP(double, usilu_f64, silu_fwd(x))
|
||||||
UNARY_OP1(float, upowf_f32, powg(x, param))
|
UNARY_OP1(float, upowf_f32, powg(x, param))
|
||||||
UNARY_OP1(double, upowf_f64, powg(x, param))
|
UNARY_OP1(double, upowf_f64, powg(x, param))
|
||||||
|
@ -1,2 +0,0 @@
|
|||||||
xcrun metal -c src/gemm/kernels/steel_gemm.metal -I src/
|
|
||||||
xcrun metallib steel_gemm.air -o src/gemm/steel_gemm.metallib
|
|
@ -1,317 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <metal_stdlib>
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
|
|
||||||
#if defined(__HAVE_BFLOAT__)
|
|
||||||
|
|
||||||
typedef bfloat bfloat16_t;
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Helpers
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// Converts a 32-bit float to the 16-bit pattern stored in _MLX_BFloat16::bits_.
// The result is the upper half of the float's bit pattern after rounding the
// discarded lower 16 bits to nearest, ties to even. NaN inputs are mapped to
// the fixed pattern 0x7FC0 instead of being rounded.
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
|
|
||||||
// Check for nan: compare the bits with sign_mask cleared against inf_mask;
// anything larger is treated as NaN and returned as 0x7FC0.
// NOTE(review): sign_mask / inf_mask come from Metal's _fp_encoding_traits,
// which is not visible here - presumably the float sign bit and the
// +infinity encoding; confirm against the Metal headers.
|
|
||||||
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
|
|
||||||
_fp_encoding_traits<float>::inf_mask) {
|
|
||||||
return uint16_t(as_type<uint32_t>(0x7FC0));
|
|
||||||
}
|
|
||||||
// Take bits: reinterpret the float as its raw 32-bit pattern
|
|
||||||
uint32_t float_bits = as_type<uint32_t>(x);
|
|
||||||
|
|
||||||
// Round to nearest even: add 0x7FFF plus the lowest bit that will be kept
// (bit 16), so an exact tie only carries when that kept bit is odd.
|
|
||||||
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
|
|
||||||
|
|
||||||
// Take upper 16 bits (the return narrows the value to uint16_t)
|
|
||||||
return float_bits >> 16;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expands a 16-bit bfloat pattern to a float. This is the inverse of the
// truncation step in float_to_bfloat_bits and needs no rounding.
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
|
|
||||||
// Upper 16 bits are the data and lower 16 bits are 0s: widen to 32 bits,
// shift the pattern into the high half, and reinterpret the result as float.
|
|
||||||
return as_type<float>((uint32_t)x << 16);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct _MLX_BFloat16;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static constexpr constant bool can_convert_to_bfloat =
|
|
||||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static constexpr constant bool can_convert_from_bfloat =
|
|
||||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Bfloat struct
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
struct _MLX_BFloat16 {
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Constructors
|
|
||||||
uint16_t bits_;
|
|
||||||
_MLX_BFloat16() thread = default;
|
|
||||||
_MLX_BFloat16() threadgroup = default;
|
|
||||||
_MLX_BFloat16() device = default;
|
|
||||||
_MLX_BFloat16() constant = default;
|
|
||||||
|
|
||||||
struct bits_to_bfloat_struct {};
|
|
||||||
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
|
|
||||||
return bits_to_bfloat_struct();
|
|
||||||
}
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
|
|
||||||
: bits_(bits) {}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Conversions to bfloat
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
|
|
||||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
|
|
||||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(T x) device
|
|
||||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
|
|
||||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Conversions from bfloat
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC operator T() const thread {
|
|
||||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC operator T() const threadgroup {
|
|
||||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC operator T() const device {
|
|
||||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC operator T() const constant {
|
|
||||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Bfloat operators
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Unary ops
|
|
||||||
// Unary minus for the software bfloat16 type. The operand is converted to
// float, negated there, and the float result is converted back to
// _MLX_BFloat16 on return through the struct's converting constructor.
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
|
|
||||||
return -static_cast<float>(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Binary operators
|
|
||||||
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
|
||||||
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
|
|
||||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
|
||||||
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
|
||||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
|
||||||
} \
|
|
||||||
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
|
||||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
|
||||||
}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Arithmetic Operators
|
|
||||||
#define bfloat_binop(_op_, _operator_) \
|
|
||||||
bfloat_binop_base( \
|
|
||||||
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, float, half, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
|
||||||
|
|
||||||
bfloat_binop(+, operator+);
|
|
||||||
bfloat_binop(-, operator-);
|
|
||||||
bfloat_binop(*, operator*);
|
|
||||||
bfloat_binop(/, operator/);
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Comparison ops
|
|
||||||
#define bfloat_compop(__op__, __operator__) \
|
|
||||||
bfloat_binop_base( \
|
|
||||||
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
|
||||||
|
|
||||||
bfloat_compop(>, operator>);
|
|
||||||
bfloat_compop(<, operator<);
|
|
||||||
bfloat_compop(>=, operator>=);
|
|
||||||
bfloat_compop(<=, operator<=);
|
|
||||||
bfloat_compop(==, operator==);
|
|
||||||
bfloat_compop(!=, operator!=);
|
|
||||||
|
|
||||||
#undef bfloat_compop
|
|
||||||
#undef bfloat_binop_base
|
|
||||||
#undef bfloat_binop_helper
|
|
||||||
#undef bfloat_binop
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Inplace Operators
|
|
||||||
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
|
|
||||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
|
||||||
addr_space _MLX_BFloat16& lhs, itype rhs) { \
|
|
||||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
|
||||||
return lhs; \
|
|
||||||
} \
|
|
||||||
constexpr METAL_FUNC addr_space itype& __operator__( \
|
|
||||||
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
|
|
||||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
|
||||||
return lhs; \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
|
|
||||||
|
|
||||||
#define bfloat_inplace_op(itype) \
|
|
||||||
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
|
|
||||||
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
|
|
||||||
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
|
|
||||||
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
|
|
||||||
|
|
||||||
bfloat_inplace_op(float);
|
|
||||||
bfloat_inplace_op(half);
|
|
||||||
bfloat_inplace_op(int16_t);
|
|
||||||
bfloat_inplace_op(int32_t);
|
|
||||||
bfloat_inplace_op(int64_t);
|
|
||||||
bfloat_inplace_op(uint16_t);
|
|
||||||
bfloat_inplace_op(uint32_t);
|
|
||||||
bfloat_inplace_op(uint64_t);
|
|
||||||
|
|
||||||
#undef bfloat_inplace_op_helper
|
|
||||||
#undef bfloat_inplace_op_addr_space_helper
|
|
||||||
#undef bfloat_inplace_op
|
|
||||||
|
|
||||||
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
|
|
||||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
|
||||||
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
|
||||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
|
||||||
return lhs; \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, device); \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, thread); \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
|
|
||||||
|
|
||||||
bfloat_inplace_op_addr_space_helper(+, operator+=);
|
|
||||||
bfloat_inplace_op_addr_space_helper(-, operator-=);
|
|
||||||
bfloat_inplace_op_addr_space_helper(*, operator*=);
|
|
||||||
bfloat_inplace_op_addr_space_helper(/, operator/=);
|
|
||||||
|
|
||||||
#undef bfloat_inplace_op_helper
|
|
||||||
#undef bfloat_inplace_op_addr_space_helper
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Bfloat typedef
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
typedef struct _MLX_BFloat16 bfloat16_t;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Bfloat numeric limits
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#pragma METAL internals : enable
|
|
||||||
|
|
||||||
namespace metal {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
|
|
||||||
static constexpr constant int digits = 8;
|
|
||||||
static constexpr constant int digits10 = 2;
|
|
||||||
static constexpr constant int max_digits10 = 4;
|
|
||||||
static constexpr constant int radix = 2;
|
|
||||||
static constexpr constant int min_exponent = -125;
|
|
||||||
static constexpr constant int min_exponent10 = -37;
|
|
||||||
static constexpr constant int max_exponent = 128;
|
|
||||||
static constexpr constant int max_exponent10 = 38;
|
|
||||||
|
|
||||||
static constexpr bfloat16_t min() {
|
|
||||||
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t lowest() {
|
|
||||||
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t max() {
|
|
||||||
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t epsilon() {
|
|
||||||
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t round_error() {
|
|
||||||
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t infinity() {
|
|
||||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t quiet_NaN() {
|
|
||||||
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t signaling_NaN() {
|
|
||||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t denorm_min() {
|
|
||||||
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// isnan overload for the software bfloat16 type. It uses the _MLX_BFloat16
// operator!=, which compares both operands as float, so the result is true
// exactly when x fails to compare equal to itself.
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
|
||||||
return x != x;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace metal
|
|
||||||
|
|
||||||
#pragma METAL internals : disable
|
|
||||||
|
|
||||||
#endif // defined(__HAVE_BFLOAT__)
|
|
||||||
|
|
||||||
#include "gemm/bf16_math.h"
|
|
@ -1,394 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "gemm/bf16.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Metal math for bfloat16
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/*
|
|
||||||
|
|
||||||
Following the Metal Shading Language Specification (Metal 3.1):
|
|
||||||
|
|
||||||
"bfloat is an extended floating point type that only allows implicit conversion
|
|
||||||
to a type of greater floating point rank. While bfloat can be implicitly
|
|
||||||
converted to float, it cannot be implicitly converted to half, and neither
|
|
||||||
float nor half can be implicitly converted to bfloat."
|
|
||||||
|
|
||||||
Further, as far as I can tell, the stdlib math/simd functions are not defined
|
|
||||||
for bfloat, and calling them with an argument of type bfloat will result in that
|
|
||||||
argument getting implicitly converted to float; the call then returns an output
|
|
||||||
that is (likely) a float, which cannot be implicitly converted into a bfloat.
|
|
||||||
|
|
||||||
This leads to situations where
|
|
||||||
bfloat a = 5.0bf;
|
|
||||||
bfloat b = metal::abs(a); // this will throw an error since abs returns float
|
|
||||||
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
|
|
||||||
|
|
||||||
For the moment, I will be adding overloaded instantiations of the math
|
|
||||||
functions to accordingly automatically handle the casting
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype abs(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype acos(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype acosh(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype asin(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype asinh(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype atan(itype y_over_x) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_atan(static_cast<ctype>(y_over_x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype atan2(itype y, itype x) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype atanh(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype ceil(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype cos(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype cosh(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype cospi(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype divide(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype exp(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype exp10(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype exp2(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fabs(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fdim(itype x, itype y) { \
|
|
||||||
ctype t = static_cast<ctype>(x - y); \
|
|
||||||
return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype floor(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fma(itype x, itype y, itype z) { \
|
|
||||||
return static_cast<otype>(__metal_fma( \
|
|
||||||
static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fmax(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
|
|
||||||
return static_cast<otype>(__metal_fmax3( \
|
|
||||||
static_cast<ctype>(x), \
|
|
||||||
static_cast<ctype>(y), \
|
|
||||||
static_cast<ctype>(z), \
|
|
||||||
mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
|
|
||||||
return static_cast<otype>(__metal_fmedian3( \
|
|
||||||
static_cast<ctype>(x), \
|
|
||||||
static_cast<ctype>(y), \
|
|
||||||
static_cast<ctype>(z), \
|
|
||||||
mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fmin(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
|
|
||||||
return static_cast<otype>(__metal_fmin3( \
|
|
||||||
static_cast<ctype>(x), \
|
|
||||||
static_cast<ctype>(y), \
|
|
||||||
static_cast<ctype>(z), \
|
|
||||||
mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fmod(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype fract(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype frexp(itype x, thread int& exp) { \
|
|
||||||
return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype ldexp(itype x, int k) { \
|
|
||||||
return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype log(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype log10(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype log2(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype max(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype max3(itype x, itype y, itype z) { \
|
|
||||||
return static_cast<otype>(__metal_fmax3( \
|
|
||||||
static_cast<ctype>(x), \
|
|
||||||
static_cast<ctype>(y), \
|
|
||||||
static_cast<ctype>(z), \
|
|
||||||
mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype median3(itype x, itype y, itype z) { \
|
|
||||||
return static_cast<otype>(__metal_fmedian3( \
|
|
||||||
static_cast<ctype>(x), \
|
|
||||||
static_cast<ctype>(y), \
|
|
||||||
static_cast<ctype>(z), \
|
|
||||||
mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype min(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype min3(itype x, itype y, itype z) { \
|
|
||||||
return static_cast<otype>(__metal_fmin3( \
|
|
||||||
static_cast<ctype>(x), \
|
|
||||||
static_cast<ctype>(y), \
|
|
||||||
static_cast<ctype>(z), \
|
|
||||||
mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype nextafter(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype pow(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype powr(itype x, itype y) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype rint(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype round(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype rsqrt(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype sin(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype sinh(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype sinpi(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype sqrt(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype tan(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype tanh(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype tanpi(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
|
|
||||||
} \
|
|
||||||
METAL_FUNC otype trunc(itype x) { \
|
|
||||||
return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace metal {
|
|
||||||
|
|
||||||
instantiate_metal_math_funcs(
|
|
||||||
bfloat16_t,
|
|
||||||
bfloat16_t,
|
|
||||||
float,
|
|
||||||
__METAL_MAYBE_FAST_MATH__);
|
|
||||||
|
|
||||||
namespace fast {
|
|
||||||
|
|
||||||
instantiate_metal_math_funcs(
|
|
||||||
bfloat16_t,
|
|
||||||
bfloat16_t,
|
|
||||||
float,
|
|
||||||
__METAL_FAST_MATH__);
|
|
||||||
|
|
||||||
} // namespace fast
|
|
||||||
|
|
||||||
namespace precise {
|
|
||||||
|
|
||||||
instantiate_metal_math_funcs(
|
|
||||||
bfloat16_t,
|
|
||||||
bfloat16_t,
|
|
||||||
float,
|
|
||||||
__METAL_PRECISE_MATH__);
|
|
||||||
|
|
||||||
} // namespace precise
|
|
||||||
|
|
||||||
} // namespace metal
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Metal simd for bfloat16
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#define instantiate_metal_simd_comm_funcs( \
|
|
||||||
itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
|
|
||||||
return ctype_to_otype( \
|
|
||||||
__metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
|
|
||||||
return ctype_to_otype( \
|
|
||||||
__metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_and_fill_down( \
|
|
||||||
itype data, itype filling_data, ushort delta, ushort modulo) { \
|
|
||||||
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
|
|
||||||
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_and_fill_down( \
|
|
||||||
itype data, itype filling_data, ushort delta) { \
|
|
||||||
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
|
|
||||||
itype_to_ctype(data), \
|
|
||||||
itype_to_ctype(filling_data), \
|
|
||||||
delta, \
|
|
||||||
__metal_get_simdgroup_size(ushort()))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_and_fill_up( \
|
|
||||||
itype data, itype filling_data, ushort delta, ushort modulo) { \
|
|
||||||
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
|
|
||||||
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_and_fill_up( \
|
|
||||||
itype data, itype filling_data, ushort delta) { \
|
|
||||||
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
|
|
||||||
itype_to_ctype(data), \
|
|
||||||
itype_to_ctype(filling_data), \
|
|
||||||
delta, \
|
|
||||||
__metal_get_simdgroup_size(ushort()))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
|
|
||||||
return ctype_to_otype( \
|
|
||||||
__metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
|
|
||||||
return ctype_to_otype( \
|
|
||||||
__metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
|
|
||||||
return ctype_to_otype( \
|
|
||||||
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
|
|
||||||
return ctype_to_otype( \
|
|
||||||
__metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
|
|
||||||
return ctype_to_otype( \
|
|
||||||
__metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_max(itype data) { \
|
|
||||||
return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_min(itype data) { \
|
|
||||||
return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
|
|
||||||
return static_cast<otype>( \
|
|
||||||
__metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_product(itype data) { \
|
|
||||||
return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_sum(itype data) { \
|
|
||||||
return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
METAL_FUNC otype simd_xor(itype data) { \
|
|
||||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#if defined(__HAVE_BFLOAT__)
|
|
||||||
|
|
||||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
|
||||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
#define bfloat16_to_uint16(x) x.bits_
|
|
||||||
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace metal {
|
|
||||||
|
|
||||||
instantiate_metal_simd_comm_funcs(
|
|
||||||
bfloat16_t,
|
|
||||||
bfloat16_t,
|
|
||||||
uint16_t,
|
|
||||||
bfloat16_to_uint16,
|
|
||||||
uint16_to_bfloat16);
|
|
||||||
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
|
||||||
|
|
||||||
} // namespace metal
|
|
@ -1,131 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <metal_stdlib>
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
|
|
||||||
struct complex64_t;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static constexpr constant bool can_convert_to_complex64 =
|
|
||||||
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static constexpr constant bool can_convert_from_complex64 =
|
|
||||||
!is_same_v<T, complex64_t> &&
|
|
||||||
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
|
|
||||||
|
|
||||||
struct complex64_t {
|
|
||||||
float real;
|
|
||||||
float imag;
|
|
||||||
|
|
||||||
// Constructors
|
|
||||||
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
|
|
||||||
|
|
||||||
// Conversions to complex64_t
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
|
||||||
constexpr complex64_t(T x) thread : real(x), imag(0) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
|
||||||
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
|
||||||
constexpr complex64_t(T x) device : real(x), imag(0) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
|
||||||
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
|
||||||
|
|
||||||
// Conversions from complex64_t
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
|
||||||
constexpr operator T() const thread {
|
|
||||||
return static_cast<T>(real);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
|
||||||
constexpr operator T() const threadgroup {
|
|
||||||
return static_cast<T>(real);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
|
||||||
constexpr operator T() const device {
|
|
||||||
return static_cast<T>(real);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
|
||||||
constexpr operator T() const constant {
|
|
||||||
return static_cast<T>(real);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Unary minus: negates both the real and the imaginary component.
constexpr complex64_t operator-(complex64_t x) {
  return complex64_t(-x.real, -x.imag);
}
|
|
||||||
|
|
||||||
// Lexicographic ">=": compares the real parts first and falls back to the
// imaginary parts only when the real parts are equal.
constexpr bool operator>=(complex64_t a, complex64_t b) {
  if (a.real > b.real) {
    return true;
  }
  return a.real == b.real && a.imag >= b.imag;
}
|
|
||||||
|
|
||||||
// Lexicographic ">": compares the real parts first and falls back to the
// imaginary parts only when the real parts are equal.
constexpr bool operator>(complex64_t a, complex64_t b) {
  if (a.real > b.real) {
    return true;
  }
  return a.real == b.real && a.imag > b.imag;
}
|
|
||||||
|
|
||||||
// "a <= b" is defined as "b >= a".
constexpr bool operator<=(complex64_t a, complex64_t b) {
  return b >= a;
}
|
|
||||||
|
|
||||||
// "a < b" is defined as "b > a".
constexpr bool operator<(complex64_t a, complex64_t b) {
  return b > a;
}
|
|
||||||
|
|
||||||
// Equality: both the real and the imaginary components must compare equal.
constexpr bool operator==(complex64_t a, complex64_t b) {
  if (!(a.real == b.real)) {
    return false;
  }
  return a.imag == b.imag;
}
|
|
||||||
|
|
||||||
// Component-wise addition.
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return complex64_t(a.real + b.real, a.imag + b.imag);
}
|
|
||||||
|
|
||||||
// Component-wise subtraction.
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return complex64_t(a.real - b.real, a.imag - b.imag);
}
|
|
||||||
|
|
||||||
// Complex product: (a.real + i*a.imag) * (b.real + i*b.imag).
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  // Each component is kept as a single expression, exactly as written before.
  const float re = a.real * b.real - a.imag * b.imag;
  const float im = a.real * b.imag + a.imag * b.real;
  return complex64_t(re, im);
}
|
|
||||||
|
|
||||||
// Complex quotient a / b: the numerator is a multiplied by the conjugate of b
// and the denominator is the squared magnitude of b. No check is made for a
// zero denominator.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  const float norm_sq = b.real * b.real + b.imag * b.imag;
  const float num_re = a.real * b.real + a.imag * b.imag;
  const float num_im = a.imag * b.real - a.real * b.imag;
  return complex64_t(num_re / norm_sq, num_im / norm_sq);
}
|
|
||||||
|
|
||||||
// Component-wise remainder: the real and imaginary parts are processed
// independently. For each component the quotient is truncated through an
// int64_t cast; when the resulting remainder is non-zero and its sign differs
// from the divisor's sign, the divisor is added to it.
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto rem_re = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  auto rem_im = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

  // Explicit parentheses: "<" binds tighter than "!=", so this is the same
  // comparison as the unparenthesized form.
  if (rem_re != 0 && ((rem_re < 0) != (b.real < 0))) {
    rem_re += b.real;
  }
  if (rem_im != 0 && ((rem_im < 0) != (b.imag < 0))) {
    rem_im += b.imag;
  }
  return complex64_t(rem_re, rem_im);
}
|
|
@ -1,292 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "gemm/loader.h"
|
|
||||||
#include "gemm/mma.h"
|
|
||||||
#include "gemm/transforms.h"
|
|
||||||
#include "utils.h"
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernel class
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
template <bool M_aligned, bool N_aligned, bool K_aligned>
|
|
||||||
struct LoopAlignment {};
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename U,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
bool MN_aligned,
|
|
||||||
bool K_aligned,
|
|
||||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
|
||||||
typename Epilogue = TransformNone<U, AccumType>>
|
|
||||||
struct GEMMKernel {
|
|
||||||
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
|
|
||||||
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
|
|
||||||
STEEL_CONST short tgp_mem_size_a =
|
|
||||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
|
||||||
STEEL_CONST short tgp_mem_size_b =
|
|
||||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
|
||||||
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
|
||||||
|
|
||||||
STEEL_CONST short tgp_size = WM * WN * 32;
|
|
||||||
|
|
||||||
using loader_a_t = BlockLoader<
|
|
||||||
T,
|
|
||||||
transpose_a ? BK : BM,
|
|
||||||
transpose_a ? BM : BK,
|
|
||||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
|
||||||
!transpose_a,
|
|
||||||
tgp_size>;
|
|
||||||
using loader_b_t = BlockLoader<
|
|
||||||
T,
|
|
||||||
transpose_b ? BN : BK,
|
|
||||||
transpose_b ? BK : BN,
|
|
||||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
|
||||||
transpose_b,
|
|
||||||
tgp_size>;
|
|
||||||
using mma_t = BlockMMA<
|
|
||||||
T,
|
|
||||||
U,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
WM,
|
|
||||||
WN,
|
|
||||||
transpose_a,
|
|
||||||
transpose_b,
|
|
||||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
|
||||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
|
||||||
AccumType,
|
|
||||||
Epilogue>;
|
|
||||||
|
|
||||||
/* Main kernel function */
|
|
||||||
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
|
||||||
static METAL_FUNC void gemm_loop(
|
|
||||||
threadgroup T* As [[threadgroup(0)]],
|
|
||||||
threadgroup T* Bs [[threadgroup(1)]],
|
|
||||||
const int gemm_k_iterations,
|
|
||||||
thread loader_a_t& loader_a,
|
|
||||||
thread loader_b_t& loader_b,
|
|
||||||
thread mma_t& mma_op,
|
|
||||||
thread const short& tgp_bm,
|
|
||||||
thread const short& tgp_bn,
|
|
||||||
thread const short& lbk,
|
|
||||||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
|
||||||
// Appease the compiler
|
|
||||||
(void)l;
|
|
||||||
|
|
||||||
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
|
||||||
|
|
||||||
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
|
||||||
|
|
||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
// Load elements into threadgroup
|
|
||||||
if (M_aligned) {
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
} else {
|
|
||||||
loader_a.load_safe(tile_dims_A);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (N_aligned) {
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
} else {
|
|
||||||
loader_b.load_safe(tile_dims_B);
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!K_aligned_) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
short2 tile_dims_A_last =
|
|
||||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
|
||||||
short2 tile_dims_B_last =
|
|
||||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
|
||||||
|
|
||||||
loader_a.load_safe(tile_dims_A_last);
|
|
||||||
loader_b.load_safe(tile_dims_B_last);
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Main kernel function */
|
|
||||||
static METAL_FUNC void run(
|
|
||||||
const device T* A [[buffer(0)]],
|
|
||||||
const device T* B [[buffer(1)]],
|
|
||||||
device U* C [[buffer(2)]],
|
|
||||||
const constant GEMMParams* params [[buffer(3)]],
|
|
||||||
threadgroup T* As [[threadgroup(0)]],
|
|
||||||
threadgroup T* Bs [[threadgroup(1)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
// Pacifying compiler
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
|
||||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
|
||||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
|
||||||
|
|
||||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Find block in A, B, C
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
|
|
||||||
A += transpose_a ? c_row : c_row * params->lda;
|
|
||||||
B += transpose_b ? c_col * params->ldb : c_col;
|
|
||||||
C += c_row * params->ldc + c_col;
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
|
||||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MNK aligned loop
|
|
||||||
if (MN_aligned) {
|
|
||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
// Load elements into threadgroup
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Loop tail
|
|
||||||
if (!K_aligned) {
|
|
||||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
|
||||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
|
||||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
|
||||||
|
|
||||||
loader_a.load_safe(tile_dims_A);
|
|
||||||
loader_b.load_safe(tile_dims_B);
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
mma_op.store_result(C, params->ldc);
|
|
||||||
return;
|
|
||||||
|
|
||||||
}
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MN unaligned loop
|
|
||||||
else { // Loop over K - unaligned case
|
|
||||||
short tgp_bm = min(BM, params->M - c_row);
|
|
||||||
short tgp_bn = min(BN, params->N - c_col);
|
|
||||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
|
||||||
|
|
||||||
if (tgp_bm == BM && tgp_bn == BN) {
|
|
||||||
gemm_loop<true, true, K_aligned>(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk);
|
|
||||||
|
|
||||||
mma_op.store_result(C, params->ldc);
|
|
||||||
return;
|
|
||||||
|
|
||||||
} else if (tgp_bn == BN) {
|
|
||||||
gemm_loop<false, true, K_aligned>(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk);
|
|
||||||
|
|
||||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
|
||||||
return;
|
|
||||||
|
|
||||||
} else if (tgp_bm == BM) {
|
|
||||||
gemm_loop<true, false, K_aligned>(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk);
|
|
||||||
|
|
||||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
|
||||||
return;
|
|
||||||
|
|
||||||
} else {
|
|
||||||
gemm_loop<false, false, K_aligned>(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk);
|
|
||||||
|
|
||||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
@ -1,5 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "params.h"
|
|
@ -1,89 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#include "gemm/bf16.h"
|
|
||||||
#include "gemm/gemm.h"
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernels
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
bool MN_aligned,
|
|
||||||
bool K_aligned>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
|
|
||||||
const device T *A [[buffer(0)]],
|
|
||||||
const device T *B [[buffer(1)]],
|
|
||||||
device T *C [[buffer(2)]],
|
|
||||||
const constant GEMMParams* params [[buffer(3)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
|
|
||||||
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
|
||||||
|
|
||||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
|
||||||
|
|
||||||
// Adjust for batch
|
|
||||||
A += params->batch_stride_a * tid.z;
|
|
||||||
B += params->batch_stride_b * tid.z;
|
|
||||||
C += params->batch_stride_c * tid.z;
|
|
||||||
|
|
||||||
gemm_kernel::run(
|
|
||||||
A, B, C,
|
|
||||||
params,
|
|
||||||
As, Bs,
|
|
||||||
simd_lane_id, simd_group_id, tid, lid
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernel initializations
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// Explicitly instantiates one `gemm` kernel and gives it the host-visible name
//   steel_gemm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>_MN_<aname>_K_<kname>
// `otype` is accepted for symmetry with the naming arguments but is not used:
// the output buffer has the input type `itype`.
#define instantiate_gemm(                                                      \
    tname, trans_a, trans_b, iname, itype, oname, otype,                       \
    bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)                   \
  template [[host_name(                                                        \
      "steel_gemm_" #tname "_" #iname "_" #oname                               \
      "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn                        \
      "_MN_" #aname "_K_" #kname)]]                                            \
  [[kernel]] void                                                              \
  gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>(    \
      const device itype *A [[buffer(0)]],                                     \
      const device itype *B [[buffer(1)]],                                     \
      device itype *C [[buffer(2)]],                                           \
      const constant GEMMParams* params [[buffer(3)]],                         \
      uint simd_lane_id [[thread_index_in_simdgroup]],                         \
      uint simd_group_id [[simdgroup_index_in_threadgroup]],                   \
      uint3 tid [[threadgroup_position_in_grid]],                              \
      uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
|
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
|
||||||
|
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
|
||||||
|
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
|
@ -1,254 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernels
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
bool MN_aligned,
|
|
||||||
bool K_aligned,
|
|
||||||
typename AccumType = float,
|
|
||||||
typename Epilogue = TransformAdd<T, AccumType>>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
|
|
||||||
const device T *A [[buffer(0)]],
|
|
||||||
const device T *B [[buffer(1)]],
|
|
||||||
const device T *C [[buffer(2)]],
|
|
||||||
device T *D [[buffer(3)]],
|
|
||||||
const constant GEMMAddMMParams* params [[buffer(4)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
|
|
||||||
// Pacifying compiler
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
using gemm_kernel =
|
|
||||||
GEMMKernel<T, T, BM, BN, BK, WM, WN,
|
|
||||||
transpose_a, transpose_b,
|
|
||||||
MN_aligned, K_aligned,
|
|
||||||
AccumType, Epilogue>;
|
|
||||||
|
|
||||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
|
||||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
|
||||||
using mma_t = typename gemm_kernel::mma_t;
|
|
||||||
|
|
||||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
|
||||||
|
|
||||||
// Adjust for batch
|
|
||||||
A += params->batch_stride_a * tid.z;
|
|
||||||
B += params->batch_stride_b * tid.z;
|
|
||||||
C += params->batch_stride_c * tid.z;
|
|
||||||
D += params->batch_stride_d * tid.z;
|
|
||||||
|
|
||||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
|
||||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
|
||||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
|
||||||
|
|
||||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Find block in A, B, C
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
|
|
||||||
A += transpose_a ? c_row : c_row * params->lda;
|
|
||||||
B += transpose_b ? c_col * params->ldb : c_col;
|
|
||||||
C += c_row * params->ldc + c_col * params->fdc;
|
|
||||||
D += c_row * params->ldd + c_col;
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
|
||||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
const Epilogue epilogue_op(params->alpha, params->beta);
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MNK aligned loop
|
|
||||||
if (MN_aligned) {
|
|
||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
// Load elements into threadgroup
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Loop tail
|
|
||||||
if (!K_aligned) {
|
|
||||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
|
||||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
|
||||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
|
||||||
|
|
||||||
loader_a.load_safe(tile_dims_A);
|
|
||||||
loader_b.load_safe(tile_dims_B);
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
|
||||||
return;
|
|
||||||
|
|
||||||
}
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MN unaligned loop
|
|
||||||
else { // Loop over K - unaligned case
|
|
||||||
short tgp_bm = min(BM, params->M - c_row);
|
|
||||||
short tgp_bn = min(BN, params->N - c_col);
|
|
||||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
|
||||||
|
|
||||||
if (tgp_bm == BM && tgp_bn == BN) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, true, K_aligned>{});
|
|
||||||
|
|
||||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
|
||||||
return;
|
|
||||||
|
|
||||||
} else if (tgp_bn == BN) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, true, K_aligned>{});
|
|
||||||
|
|
||||||
return mma_op.store_result_safe(
|
|
||||||
D, params->ldd,
|
|
||||||
C, params->ldc, params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op);
|
|
||||||
|
|
||||||
} else if (tgp_bm == BM) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, false, K_aligned>{});
|
|
||||||
|
|
||||||
return mma_op.store_result_safe(
|
|
||||||
D, params->ldd,
|
|
||||||
C, params->ldc, params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, false, K_aligned>{});
|
|
||||||
|
|
||||||
return mma_op.store_result_safe(
|
|
||||||
D, params->ldd,
|
|
||||||
C, params->ldc, params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernel initializations
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// Explicitly instantiates one `addmm` kernel and gives it the host-visible name
//   steel_addmm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>_MN_<aname>_K_<kname>_<ep_name>
// The accumulator type is fixed to float and `epilogue` is a template that is
// applied as epilogue<itype, float>. `otype` is accepted for naming symmetry
// but is not used: the output buffer has the input type `itype`.
#define instantiate_gemm(                                                      \
    tname, trans_a, trans_b, iname, itype, oname, otype,                       \
    bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned,                   \
    ep_name, epilogue)                                                         \
  template [[host_name(                                                        \
      "steel_addmm_" #tname "_" #iname "_" #oname                              \
      "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn                        \
      "_MN_" #aname "_K_" #kname "_" #ep_name)]]                               \
  [[kernel]] void                                                              \
  addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned,    \
        float, epilogue<itype, float>>(                                        \
      const device itype *A [[buffer(0)]],                                     \
      const device itype *B [[buffer(1)]],                                     \
      const device itype *C [[buffer(2)]],                                     \
      device itype *D [[buffer(3)]],                                           \
      const constant GEMMAddMMParams* params [[buffer(4)]],                    \
      uint simd_lane_id [[thread_index_in_simdgroup]],                         \
      uint simd_group_id [[simdgroup_index_in_threadgroup]],                   \
      uint3 tid [[threadgroup_position_in_grid]],                              \
      uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
|
|
||||||
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
|
|
||||||
|
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
|
||||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
|
||||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
|
||||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
|
||||||
|
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
|
||||||
|
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
|
@ -1,280 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernels
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
typename U,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
bool MN_aligned,
|
|
||||||
bool K_aligned>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
|
|
||||||
const device T *A [[buffer(0)]],
|
|
||||||
const device T *B [[buffer(1)]],
|
|
||||||
device U *C [[buffer(2)]],
|
|
||||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
|
||||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
|
||||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
|
||||||
using mma_t = typename gemm_kernel::mma_t;
|
|
||||||
|
|
||||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
|
||||||
|
|
||||||
const int tid_x = tid.x;
|
|
||||||
const int tid_y = tid.y;
|
|
||||||
const int tid_z = tid.z;
|
|
||||||
|
|
||||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find block in A, B, C
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
const int k_start = params->split_k_partition_size * tid_z;
|
|
||||||
|
|
||||||
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
|
|
||||||
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
|
|
||||||
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
|
||||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
short tgp_bm = min(BM, params->M - c_row);
|
|
||||||
short tgp_bn = min(BN, params->N - c_col);
|
|
||||||
short leftover_bk = params->K % BK;
|
|
||||||
|
|
||||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, true, true>{});
|
|
||||||
} else if (tgp_bn == BN) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, true, true>{});
|
|
||||||
} else if (tgp_bm == BM) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, false, true>{});
|
|
||||||
} else {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, false, true>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
if ((tid_z + 1) == (params->split_k_partitions)) {
|
|
||||||
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
|
|
||||||
if(!K_aligned || gemm_k_iter_remaining > 0)
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iter_remaining,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, false, K_aligned>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
|
||||||
mma_op.store_result(C, params->ldc);
|
|
||||||
} else {
|
|
||||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernel initializations
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
|
||||||
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
|
||||||
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
|
||||||
const device itype *A [[buffer(0)]], \
|
|
||||||
const device itype *B [[buffer(1)]], \
|
|
||||||
device otype *C [[buffer(2)]], \
|
|
||||||
const constant GEMMSpiltKParams* params [[buffer(3)]], \
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
|
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
|
||||||
|
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
|
||||||
|
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Split k accumulation kernel
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename AccT,
|
|
||||||
typename OutT,
|
|
||||||
typename Epilogue = TransformNone<OutT, AccT>>
|
|
||||||
[[kernel]] void gemm_splitk_accum(
|
|
||||||
const device AccT *C_split [[buffer(0)]],
|
|
||||||
device OutT *D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
|
||||||
|
|
||||||
// Adjust D and C_split
|
|
||||||
D += gid.x + gid.y * ldd;
|
|
||||||
C_split += gid.x + gid.y * ldd;
|
|
||||||
|
|
||||||
int offset = 0;
|
|
||||||
AccT out = 0;
|
|
||||||
|
|
||||||
for(int i = 0; i < k_partitions; i++) {
|
|
||||||
out += C_split[offset];
|
|
||||||
offset += partition_stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write output
|
|
||||||
D[0] = Epilogue::apply(out);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename AccT,
|
|
||||||
typename OutT,
|
|
||||||
typename Epilogue = TransformAxpby<OutT, AccT>>
|
|
||||||
[[kernel]] void gemm_splitk_accum_axpby(
|
|
||||||
const device AccT *C_split [[buffer(0)]],
|
|
||||||
device OutT *D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
const device OutT *C [[buffer(5)]],
|
|
||||||
const constant int& ldc [[buffer(6)]],
|
|
||||||
const constant int& fdc [[buffer(7)]],
|
|
||||||
const constant float& alpha [[buffer(8)]],
|
|
||||||
const constant float& beta [[buffer(9)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
|
||||||
|
|
||||||
// Adjust C, D and C_split
|
|
||||||
C += gid.x * fdc + gid.y * ldc;
|
|
||||||
D += gid.x + gid.y * ldd;
|
|
||||||
C_split += gid.x + gid.y * ldd;
|
|
||||||
|
|
||||||
int offset = 0;
|
|
||||||
AccT out = 0;
|
|
||||||
|
|
||||||
for(int i = 0; i < k_partitions; i++) {
|
|
||||||
out += C_split[offset];
|
|
||||||
offset += partition_stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write output
|
|
||||||
Epilogue op(alpha, beta);
|
|
||||||
D[0] = op.apply(out, *C);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#define instantiate_accum(oname, otype, aname, atype) \
|
|
||||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
|
|
||||||
[[kernel]] void gemm_splitk_accum<atype, otype>( \
|
|
||||||
const device atype *C_split [[buffer(0)]], \
|
|
||||||
device otype *D [[buffer(1)]], \
|
|
||||||
const constant int& k_partitions [[buffer(2)]], \
|
|
||||||
const constant int& partition_stride [[buffer(3)]], \
|
|
||||||
const constant int& ldd [[buffer(4)]], \
|
|
||||||
uint2 gid [[thread_position_in_grid]]); \
|
|
||||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
|
|
||||||
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
|
|
||||||
const device atype *C_split [[buffer(0)]], \
|
|
||||||
device otype *D [[buffer(1)]], \
|
|
||||||
const constant int& k_partitions [[buffer(2)]], \
|
|
||||||
const constant int& partition_stride [[buffer(3)]], \
|
|
||||||
const constant int& ldd [[buffer(4)]], \
|
|
||||||
const device otype *C [[buffer(5)]], \
|
|
||||||
const constant int& ldc [[buffer(6)]], \
|
|
||||||
const constant int& fdc [[buffer(7)]], \
|
|
||||||
const constant float& alpha [[buffer(8)]], \
|
|
||||||
const constant float& beta [[buffer(9)]], \
|
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
|
||||||
instantiate_accum(float16, half, float32, float);
|
|
||||||
instantiate_accum(float32, float, float32, float);
|
|
@ -1,125 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "utils2.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Loading helper
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
short BROWS,
|
|
||||||
short BCOLS,
|
|
||||||
short dst_ld,
|
|
||||||
short reduction_dim,
|
|
||||||
short tgp_size,
|
|
||||||
short alignment = 1,
|
|
||||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
|
||||||
short TCOLS = BCOLS / n_reads,
|
|
||||||
short TROWS = tgp_size / TCOLS>
|
|
||||||
struct BlockLoader {
|
|
||||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
|
||||||
STEEL_CONST short vec_size = n_reads;
|
|
||||||
|
|
||||||
// Leading dimension for src
|
|
||||||
const int src_ld;
|
|
||||||
const int tile_stride;
|
|
||||||
|
|
||||||
// Thread location indices
|
|
||||||
const short thread_idx;
|
|
||||||
const short bi;
|
|
||||||
const short bj;
|
|
||||||
|
|
||||||
// threadgroup and device memory
|
|
||||||
threadgroup T* dst;
|
|
||||||
const device T* src;
|
|
||||||
|
|
||||||
struct alignas(alignment * sizeof(T)) ReadVector {
|
|
||||||
uint8_t v[sizeof(T) * vec_size];
|
|
||||||
};
|
|
||||||
|
|
||||||
/* Constructor */
|
|
||||||
METAL_FUNC BlockLoader(
|
|
||||||
const device T* src_,
|
|
||||||
const int src_ld_,
|
|
||||||
threadgroup T* dst_,
|
|
||||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
|
||||||
: src_ld(src_ld_),
|
|
||||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
|
||||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
|
||||||
bi(thread_idx / TCOLS),
|
|
||||||
bj(vec_size * (thread_idx % TCOLS)),
|
|
||||||
dst(dst_ + bi * dst_ld + bj),
|
|
||||||
src(src_ + bi * src_ld + bj) {}
|
|
||||||
|
|
||||||
/* Load from device memory into threadgroup memory - without bound checking */
|
|
||||||
METAL_FUNC void load_unsafe() const {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < BROWS; i += TROWS) {
|
|
||||||
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
|
||||||
*((const device ReadVector*)(&src[i * src_ld]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Load from device memory into threadgroup memory - with bound checking */
|
|
||||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
|
||||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
|
||||||
|
|
||||||
// Skip loading if thread has no valid reads
|
|
||||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < BROWS; i += TROWS) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
dst[i * dst_ld + j] = T(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use fast thread memory for bound checks
|
|
||||||
bool tmp_idx[vec_size];
|
|
||||||
T tmp_val[vec_size];
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < BROWS; i += TROWS) {
|
|
||||||
// Make sure tmp_idx only contains valid indices
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read valid indices into tmp_val
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Zero out unneeded values
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy values to threadgroup memory
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
dst[i * dst_ld + j] = tmp_val[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Iteration helper */
|
|
||||||
METAL_FUNC void next() {
|
|
||||||
src += tile_stride;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
@ -1,264 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "gemm/transforms.h"
|
|
||||||
#include "utils.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MMA helper
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename U,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
short lda_tgp,
|
|
||||||
short ldb_tgp,
|
|
||||||
typename AccumType = float,
|
|
||||||
typename Epilogue = TransformNone<U, AccumType>>
|
|
||||||
struct BlockMMA {
|
|
||||||
// Warp tile simdgroup matrix strides along M
|
|
||||||
STEEL_CONST short TM_stride = 8 * WM;
|
|
||||||
// Warp tile simdgroup matrix strides along N
|
|
||||||
STEEL_CONST short TN_stride = 8 * WN;
|
|
||||||
|
|
||||||
// Warp tile size along M
|
|
||||||
STEEL_CONST short TM = BM / TM_stride;
|
|
||||||
// Warp tile size along N
|
|
||||||
STEEL_CONST short TN = BN / TN_stride;
|
|
||||||
|
|
||||||
// Strides of A, B along reduction axis
|
|
||||||
STEEL_CONST short simd_stride_a = {
|
|
||||||
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
|
||||||
STEEL_CONST short simd_stride_b = {
|
|
||||||
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
|
||||||
|
|
||||||
// Jump between elements
|
|
||||||
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
|
||||||
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
|
||||||
|
|
||||||
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
|
||||||
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
|
||||||
|
|
||||||
// Simdgroup matrices
|
|
||||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
|
||||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
|
||||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
|
||||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
|
||||||
|
|
||||||
// Offsets within threadgroup
|
|
||||||
const short tm;
|
|
||||||
const short tn;
|
|
||||||
|
|
||||||
short sm;
|
|
||||||
short sn;
|
|
||||||
|
|
||||||
short As_offset;
|
|
||||||
short Bs_offset;
|
|
||||||
|
|
||||||
/* Constructor */
|
|
||||||
METAL_FUNC BlockMMA(
|
|
||||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
|
||||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
|
||||||
// Determine thread position in simdgroup matrix
|
|
||||||
short qid = simd_lane_id / 4;
|
|
||||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
|
||||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
|
||||||
|
|
||||||
// Determine thread and simdgroup offset
|
|
||||||
As_offset =
|
|
||||||
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
|
||||||
Bs_offset =
|
|
||||||
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
|
||||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
As += As_offset;
|
|
||||||
Bs += Bs_offset;
|
|
||||||
|
|
||||||
// Iterate over BK in blocks of 8
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short kk = 0; kk < BK; kk += 8) {
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Load elements from threadgroup A as simdgroup matrices
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < TM; i++) {
|
|
||||||
Asimd[i].thread_elements()[0] =
|
|
||||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
|
||||||
Asimd[i].thread_elements()[1] =
|
|
||||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Load elements from threadgroup B as simdgroup matrices
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < TN; j++) {
|
|
||||||
Bsimd[j].thread_elements()[0] =
|
|
||||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
|
||||||
Bsimd[j].thread_elements()[1] =
|
|
||||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Multiply and accumulate into result simdgroup matrices
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < TM; i++) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < TN; j++) {
|
|
||||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(
|
|
||||||
results[i * TN + j_serp],
|
|
||||||
Asimd[i],
|
|
||||||
Bsimd[j_serp],
|
|
||||||
results[i * TN + j_serp]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Progress to next simdgroup tile
|
|
||||||
As += tile_stride_a;
|
|
||||||
Bs += tile_stride_b;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Store results from simdgroup_matrix results into device memory */
|
|
||||||
METAL_FUNC void store_result(device U* C, const int ldc) const {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
C += (sm + tm) * ldc + tn + sn;
|
|
||||||
|
|
||||||
// Loop over all simdgroup tiles
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < TM; i++) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
|
||||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue
|
|
||||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
|
||||||
|
|
||||||
// Write out C
|
|
||||||
C[offset] = outs[0];
|
|
||||||
C[offset + 1] = outs[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void
|
|
||||||
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
C += (sm + tm) * ldc + (tn + sn);
|
|
||||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < TM; i++) {
|
|
||||||
if (i * TM_stride < dst_tile_dims.y) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int j = 0; j < TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
|
||||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue and output C
|
|
||||||
if (j * TN_stride < dst_tile_dims.x) {
|
|
||||||
C[offset] = Epilogue::apply(accum[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
|
||||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Store results from simdgroup_matrix results into device memory */
|
|
||||||
METAL_FUNC void store_result(
|
|
||||||
device U* D,
|
|
||||||
const int ldd,
|
|
||||||
const device U* C,
|
|
||||||
const int ldc,
|
|
||||||
const int fdc,
|
|
||||||
thread const Epilogue& epilogue_op) const {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
|
||||||
D += (sm + tm) * ldd + tn + sn;
|
|
||||||
|
|
||||||
// Loop over all simdgroup tiles
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < TM; i++) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
|
||||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
|
||||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue
|
|
||||||
U outs[2] = {
|
|
||||||
epilogue_op.apply(accum[0], C[offset_c]),
|
|
||||||
epilogue_op.apply(accum[1], C[offset_c + fdc])};
|
|
||||||
|
|
||||||
// Write out D
|
|
||||||
D[offset_d] = outs[0];
|
|
||||||
D[offset_d + 1] = outs[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void store_result_safe(
|
|
||||||
device U* D,
|
|
||||||
const int ldd,
|
|
||||||
const device U* C,
|
|
||||||
const int ldc,
|
|
||||||
const int fdc,
|
|
||||||
short2 dst_tile_dims,
|
|
||||||
thread const Epilogue& epilogue_op) const {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
|
||||||
D += (sm + tm) * ldd + tn + sn;
|
|
||||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < TM; i++) {
|
|
||||||
if (i * TM_stride < dst_tile_dims.y) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int j = 0; j < TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
|
||||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
|
||||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue and output D
|
|
||||||
if (j * TN_stride < dst_tile_dims.x) {
|
|
||||||
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
|
||||||
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
@ -1,79 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM param classes
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
struct GEMMParams {
|
|
||||||
const int M;
|
|
||||||
const int N;
|
|
||||||
const int K;
|
|
||||||
|
|
||||||
const int lda;
|
|
||||||
const int ldb;
|
|
||||||
const int ldc;
|
|
||||||
|
|
||||||
const int tiles_n;
|
|
||||||
const int tiles_m;
|
|
||||||
|
|
||||||
const int batch_stride_a;
|
|
||||||
const int batch_stride_b;
|
|
||||||
const int batch_stride_c;
|
|
||||||
|
|
||||||
const int swizzle_log;
|
|
||||||
const int gemm_k_iterations_aligned;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct GEMMSpiltKParams {
|
|
||||||
const int M;
|
|
||||||
const int N;
|
|
||||||
const int K;
|
|
||||||
|
|
||||||
const int lda;
|
|
||||||
const int ldb;
|
|
||||||
const int ldc;
|
|
||||||
|
|
||||||
const int tiles_n;
|
|
||||||
const int tiles_m;
|
|
||||||
|
|
||||||
const int split_k_partitions;
|
|
||||||
const int split_k_partition_stride;
|
|
||||||
const int split_k_partition_size;
|
|
||||||
|
|
||||||
const int gemm_k_iterations_aligned;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct GEMMAddMMParams {
|
|
||||||
const int M;
|
|
||||||
const int N;
|
|
||||||
const int K;
|
|
||||||
|
|
||||||
const int lda;
|
|
||||||
const int ldb;
|
|
||||||
const int ldc;
|
|
||||||
const int ldd;
|
|
||||||
|
|
||||||
const int tiles_n;
|
|
||||||
const int tiles_m;
|
|
||||||
|
|
||||||
const int batch_stride_a;
|
|
||||||
const int batch_stride_b;
|
|
||||||
const int batch_stride_c;
|
|
||||||
const int batch_stride_d;
|
|
||||||
|
|
||||||
const int swizzle_log;
|
|
||||||
const int gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
const float alpha;
|
|
||||||
const float beta;
|
|
||||||
|
|
||||||
const int fdc;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
Binary file not shown.
@ -1,63 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "utils.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Transforms and Epilogues
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
template <typename OutT, typename InT>
|
|
||||||
struct TransformNone {
|
|
||||||
static METAL_FUNC OutT apply(InT x) {
|
|
||||||
return static_cast<OutT>(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
static METAL_FUNC OutT apply(InT x, OutT) {
|
|
||||||
return static_cast<OutT>(x);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename OutT, typename InT>
|
|
||||||
struct TransformAdd {
|
|
||||||
TransformAdd(const float, const float) {}
|
|
||||||
|
|
||||||
static METAL_FUNC OutT apply(InT x, OutT c) {
|
|
||||||
return static_cast<OutT>(x) + c;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename OutT, typename InT>
|
|
||||||
struct TransformAxpby {
|
|
||||||
const float alpha;
|
|
||||||
const float beta;
|
|
||||||
|
|
||||||
TransformAxpby(const float alpha_, const float beta_)
|
|
||||||
: alpha(alpha_), beta(beta_) {}
|
|
||||||
|
|
||||||
METAL_FUNC OutT apply(InT x, OutT c) const {
|
|
||||||
return static_cast<OutT>(x * alpha + (beta * c));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct AccumHelper {
|
|
||||||
typedef float accum_type;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct BlockSwizzle {
|
|
||||||
static METAL_FUNC int2
|
|
||||||
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
|
|
||||||
const int tid_x = (tid.x) >> swizzle_log;
|
|
||||||
const int tid_y =
|
|
||||||
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
|
|
||||||
return int2(tid_x, tid_y);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
@ -1,276 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <metal_math>
|
|
||||||
#include "gemm/bf16.h"
|
|
||||||
#include "gemm/complex.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Type limits utils
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename U>
|
|
||||||
struct Limits {
|
|
||||||
static const constant U max = metal::numeric_limits<U>::max();
|
|
||||||
static const constant U min = metal::numeric_limits<U>::min();
|
|
||||||
static const constant U finite_max = metal::numeric_limits<U>::max();
|
|
||||||
static const constant U finite_min = metal::numeric_limits<U>::min();
|
|
||||||
};
|
|
||||||
|
|
||||||
#define instantiate_default_limit(type) \
|
|
||||||
template <> \
|
|
||||||
struct Limits<type> { \
|
|
||||||
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
|
||||||
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
|
||||||
static constexpr constant type finite_max = \
|
|
||||||
metal::numeric_limits<type>::max(); \
|
|
||||||
static constexpr constant type finite_min = \
|
|
||||||
metal::numeric_limits<type>::min(); \
|
|
||||||
};
|
|
||||||
|
|
||||||
instantiate_default_limit(uint8_t);
|
|
||||||
instantiate_default_limit(uint16_t);
|
|
||||||
instantiate_default_limit(uint32_t);
|
|
||||||
instantiate_default_limit(uint64_t);
|
|
||||||
instantiate_default_limit(int8_t);
|
|
||||||
instantiate_default_limit(int16_t);
|
|
||||||
instantiate_default_limit(int32_t);
|
|
||||||
instantiate_default_limit(int64_t);
|
|
||||||
|
|
||||||
#define instantiate_float_limit(type) \
|
|
||||||
template <> \
|
|
||||||
struct Limits<type> { \
|
|
||||||
static constexpr constant type max = \
|
|
||||||
metal::numeric_limits<type>::infinity(); \
|
|
||||||
static constexpr constant type min = \
|
|
||||||
-metal::numeric_limits<type>::infinity(); \
|
|
||||||
static constexpr constant type finite_max = \
|
|
||||||
metal::numeric_limits<type>::max(); \
|
|
||||||
static constexpr constant type finite_min = \
|
|
||||||
-metal::numeric_limits<type>::max(); \
|
|
||||||
};
|
|
||||||
|
|
||||||
instantiate_float_limit(half);
|
|
||||||
instantiate_float_limit(float);
|
|
||||||
instantiate_float_limit(bfloat16_t);
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct Limits<bool> {
|
|
||||||
static constexpr constant bool max = true;
|
|
||||||
static constexpr constant bool min = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Indexing utils
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
inline size_t elem_to_loc(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides,
|
|
||||||
int ndim) {
|
|
||||||
size_t loc = 0;
|
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
|
||||||
loc += (elem % shape[i]) * strides[i];
|
|
||||||
elem /= shape[i];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elem_to_loc(
|
|
||||||
uint elem,
|
|
||||||
constant const int* shape,
|
|
||||||
constant const size_t* strides,
|
|
||||||
int ndim) {
|
|
||||||
size_t loc = 0;
|
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
|
||||||
loc += (elem % shape[i]) * strides[i];
|
|
||||||
elem /= shape[i];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int NDIM>
|
|
||||||
inline uint2 elem_to_loc_2_nd(
|
|
||||||
uint3 elem,
|
|
||||||
constant const int shape[NDIM],
|
|
||||||
constant const size_t a_strides[NDIM],
|
|
||||||
constant const size_t b_strides[NDIM]) {
|
|
||||||
uint2 loc = {
|
|
||||||
static_cast<uint>(
|
|
||||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
|
||||||
static_cast<uint>(
|
|
||||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
|
||||||
for (int d = NDIM - 3; d >= 0; --d) {
|
|
||||||
uint l = elem.z % shape[d];
|
|
||||||
loc.x += l * a_strides[d];
|
|
||||||
loc.y += l * b_strides[d];
|
|
||||||
elem.z /= shape[d];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int NDIM>
|
|
||||||
inline size_t elem_to_loc_nd(
|
|
||||||
uint3 elem,
|
|
||||||
constant const int shape[NDIM],
|
|
||||||
constant const size_t strides[NDIM]) {
|
|
||||||
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
|
||||||
for (int d = NDIM - 3; d >= 0; --d) {
|
|
||||||
loc += (elem.z % shape[d]) * strides[d];
|
|
||||||
elem.z /= shape[d];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
|
|
||||||
return elem * stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
|
|
||||||
return elem.x * strides[1] + elem.y * strides[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
|
|
||||||
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non templated version to handle arbitrary dims
|
|
||||||
inline size_t elem_to_loc(
|
|
||||||
uint3 elem,
|
|
||||||
constant const int* shape,
|
|
||||||
constant const size_t* strides,
|
|
||||||
int ndim) {
|
|
||||||
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
|
||||||
for (int d = ndim - 3; d >= 0; --d) {
|
|
||||||
loc += (elem.z % shape[d]) * strides[d];
|
|
||||||
elem.z /= shape[d];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline uint2 elem_to_loc_2_nd(
|
|
||||||
uint3 elem,
|
|
||||||
constant const int* shape,
|
|
||||||
constant const size_t* a_strides,
|
|
||||||
constant const size_t* b_strides,
|
|
||||||
int ndim) {
|
|
||||||
uint2 loc = {
|
|
||||||
static_cast<uint>(
|
|
||||||
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
|
||||||
static_cast<uint>(
|
|
||||||
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
|
||||||
for (int d = ndim - 3; d >= 0; --d) {
|
|
||||||
uint l = elem.z % shape[d];
|
|
||||||
loc.x += l * a_strides[d];
|
|
||||||
loc.y += l * b_strides[d];
|
|
||||||
elem.z /= shape[d];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int NDIM>
|
|
||||||
inline uint elem_to_loc_nd(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides);
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline uint elem_to_loc_nd<1>(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides) {
|
|
||||||
return (elem % shape[0]) * strides[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline uint elem_to_loc_nd<2>(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides) {
|
|
||||||
uint loc = (elem % shape[1]) * strides[1];
|
|
||||||
elem /= shape[1];
|
|
||||||
loc += (elem % shape[0]) * strides[0];
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline uint elem_to_loc_nd<3>(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides) {
|
|
||||||
uint loc = (elem % shape[2]) * strides[2];
|
|
||||||
elem /= shape[2];
|
|
||||||
loc += (elem % shape[1]) * strides[1];
|
|
||||||
elem /= shape[1];
|
|
||||||
loc += (elem % shape[0]) * strides[0];
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline uint elem_to_loc_nd<4>(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides) {
|
|
||||||
uint loc = (elem % shape[3]) * strides[3];
|
|
||||||
elem /= shape[3];
|
|
||||||
loc += (elem % shape[2]) * strides[2];
|
|
||||||
elem /= shape[2];
|
|
||||||
loc += (elem % shape[1]) * strides[1];
|
|
||||||
elem /= shape[1];
|
|
||||||
loc += (elem % shape[0]) * strides[0];
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Calculation utils
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/** Compute ceil((float)N/(float)M) */
|
|
||||||
inline size_t ceildiv(size_t N, size_t M) {
|
|
||||||
return (N + M - 1) / M;
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
|
||||||
inline float log1p(float x) {
|
|
||||||
float xp1 = 1.0f + x;
|
|
||||||
if (xp1 == Limits<float>::max) {
|
|
||||||
return Limits<float>::max;
|
|
||||||
}
|
|
||||||
if (xp1 == 1.0f) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bfloat16_t log1p(bfloat16_t x) {
|
|
||||||
float xp1 = 1.0f + static_cast<float>(x);
|
|
||||||
if (xp1 == Limits<float>::max) {
|
|
||||||
return Limits<bfloat16_t>::max;
|
|
||||||
}
|
|
||||||
if (xp1 == 1.0f) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// SIMD shuffle ops
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
|
|
||||||
return as_type<uint64_t>(
|
|
||||||
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
|
|
||||||
return as_type<int64_t>(
|
|
||||||
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool simd_shuffle_down(bool data, uint16_t delta) {
|
|
||||||
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
|
||||||
}
|
|
@ -1,9 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <metal_stdlib>
|
|
||||||
#include "gemm/host.h"
|
|
||||||
|
|
||||||
#define STEEL_CONST static constant constexpr const
|
|
||||||
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
|
@ -1,7 +1,6 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
|
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||||
NSUInteger,
|
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -17,7 +16,6 @@ const CONV: &str = include_str!("conv.metal");
|
|||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &str = include_str!("reduce.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &str = include_str!("random.metal");
|
||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||||
const GEMM: &[u8] = include_bytes!("gemm/steel_gemm.metallib");
|
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||||
|
|
||||||
/// Most kernels apply similarly across the tensors
|
/// Most kernels apply similarly across the tensors
|
||||||
@ -124,7 +122,6 @@ pub enum Source {
|
|||||||
Cast,
|
Cast,
|
||||||
Reduce,
|
Reduce,
|
||||||
Mfa,
|
Mfa,
|
||||||
Gemm,
|
|
||||||
Conv,
|
Conv,
|
||||||
Random,
|
Random,
|
||||||
Quantized,
|
Quantized,
|
||||||
@ -186,7 +183,7 @@ macro_rules! ops{
|
|||||||
pub mod unary {
|
pub mod unary {
|
||||||
ops!(
|
ops!(
|
||||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||||
tanh, recip
|
tanh, recip, silu
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
@ -251,7 +248,6 @@ impl Kernels {
|
|||||||
Source::Random => RANDOM,
|
Source::Random => RANDOM,
|
||||||
Source::Quantized => QUANTIZED,
|
Source::Quantized => QUANTIZED,
|
||||||
Source::Mfa => panic!("Invalid lib"),
|
Source::Mfa => panic!("Invalid lib"),
|
||||||
Source::Gemm => panic!("Invalid lib"),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -275,14 +271,6 @@ impl Kernels {
|
|||||||
))
|
))
|
||||||
})?
|
})?
|
||||||
}
|
}
|
||||||
Source::Gemm => {
|
|
||||||
let source_data = GEMM;
|
|
||||||
device.new_library_with_data(source_data).map_err(|e| {
|
|
||||||
MetalKernelError::LoadLibraryError(format!(
|
|
||||||
"Candle metal requires macosx > 13.0 or higher, cannot load GEMM: {e}"
|
|
||||||
))
|
|
||||||
})?
|
|
||||||
}
|
|
||||||
source => {
|
source => {
|
||||||
let source_content = self.get_library_source(source);
|
let source_content = self.get_library_source(source);
|
||||||
device
|
device
|
||||||
@ -1242,34 +1230,6 @@ impl ConstantValues {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn string_to_static_str(s: String) -> &'static str {
|
|
||||||
Box::leak(s.into_boxed_str())
|
|
||||||
}
|
|
||||||
|
|
||||||
use core::ffi::c_int;
|
|
||||||
|
|
||||||
#[repr(C)]
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct GEMMParams {
|
|
||||||
m: c_int,
|
|
||||||
n: c_int,
|
|
||||||
k: c_int,
|
|
||||||
|
|
||||||
lda: c_int,
|
|
||||||
ldb: c_int,
|
|
||||||
ldc: c_int,
|
|
||||||
|
|
||||||
tiles_n: c_int,
|
|
||||||
tiles_m: c_int,
|
|
||||||
|
|
||||||
batch_stride_a: c_int,
|
|
||||||
batch_stride_b: c_int,
|
|
||||||
batch_stride_c: c_int,
|
|
||||||
|
|
||||||
swizzle_log: c_int,
|
|
||||||
gemm_k_iterations_aligned: c_int,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_gemm(
|
pub fn call_gemm(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
@ -1291,10 +1251,10 @@ pub fn call_gemm(
|
|||||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
let (a_trans, lda) = if lhs_m1 == 1 && lhs_m2 == k {
|
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
(false, k as c_int)
|
false
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(true, n as c_int)
|
true
|
||||||
} else {
|
} else {
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
@ -1302,10 +1262,10 @@ pub fn call_gemm(
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?;
|
})?;
|
||||||
};
|
};
|
||||||
let (b_trans, ldb) = if rhs_m1 == 1 && rhs_m2 == n {
|
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
(false, n as c_int)
|
false
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(true, k as c_int)
|
true
|
||||||
} else {
|
} else {
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
@ -1313,195 +1273,119 @@ pub fn call_gemm(
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?;
|
})?;
|
||||||
};
|
};
|
||||||
// let d_trans = false;
|
let d_trans = false;
|
||||||
// let alpha = 1.0f32;
|
let alpha = 1.0f32;
|
||||||
// let beta = 0.0f32;
|
let beta = 0.0f32;
|
||||||
// let batched = b > 1;
|
let batched = b > 1;
|
||||||
// let fused_activation = false;
|
let fused_activation = false;
|
||||||
// let fused_bias = false;
|
let fused_bias = false;
|
||||||
// let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
||||||
// let m_simd = 8;
|
let m_simd = 8;
|
||||||
// let n_simd = 8;
|
let n_simd = 8;
|
||||||
// let k_simd = 64;
|
let k_simd = 64;
|
||||||
// let m_splits = 1;
|
let m_splits = 1;
|
||||||
// let n_splits = 1;
|
let n_splits = 1;
|
||||||
// (m_simd, n_simd, k_simd, m_splits, n_splits)
|
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||||
// } else {
|
} else {
|
||||||
// let m_simd = 40;
|
let m_simd = 40;
|
||||||
// let n_simd = 40;
|
let n_simd = 40;
|
||||||
// let k_simd = 32;
|
let k_simd = 32;
|
||||||
// let m_splits = 1;
|
let m_splits = 1;
|
||||||
// let n_splits = 1;
|
let n_splits = 1;
|
||||||
// (m_simd, n_simd, k_simd, m_splits, n_splits)
|
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||||
// };
|
};
|
||||||
// let constants = Some(ConstantValues::new(vec![
|
let constants = Some(ConstantValues::new(vec![
|
||||||
// (0, Value::USize(m)),
|
(0, Value::USize(m)),
|
||||||
// (1, Value::USize(n)),
|
(1, Value::USize(n)),
|
||||||
// (2, Value::USize(k)),
|
(2, Value::USize(k)),
|
||||||
// (10, Value::Bool(a_trans)),
|
(10, Value::Bool(a_trans)),
|
||||||
// (11, Value::Bool(b_trans)),
|
(11, Value::Bool(b_trans)),
|
||||||
// (13, Value::Bool(d_trans)),
|
(13, Value::Bool(d_trans)),
|
||||||
// (20, Value::F32(alpha)),
|
(20, Value::F32(alpha)),
|
||||||
// (21, Value::F32(beta)),
|
(21, Value::F32(beta)),
|
||||||
// (100, Value::Bool(batched)),
|
(100, Value::Bool(batched)),
|
||||||
// (101, Value::Bool(fused_activation)),
|
(101, Value::Bool(fused_activation)),
|
||||||
// // Garbage
|
// Garbage
|
||||||
// (102, Value::Bool(false)),
|
(102, Value::Bool(false)),
|
||||||
// (103, Value::Bool(false)),
|
(103, Value::Bool(false)),
|
||||||
// (113, Value::Bool(false)),
|
(113, Value::Bool(false)),
|
||||||
// (50_000, Value::Bool(false)),
|
(50_000, Value::Bool(false)),
|
||||||
// // End garbage
|
// End garbage
|
||||||
// (200, Value::U16(m_simd)),
|
(200, Value::U16(m_simd)),
|
||||||
// (201, Value::U16(n_simd)),
|
(201, Value::U16(n_simd)),
|
||||||
// (202, Value::U16(k_simd)),
|
(202, Value::U16(k_simd)),
|
||||||
// (210, Value::U16(m_splits)),
|
(210, Value::U16(m_splits)),
|
||||||
// (211, Value::U16(n_splits)),
|
(211, Value::U16(n_splits)),
|
||||||
// (50_001, Value::Bool(fused_bias)),
|
(50_001, Value::Bool(fused_bias)),
|
||||||
// ]));
|
]));
|
||||||
let a_trans_name = if a_trans { "t" } else { "n" };
|
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
||||||
let b_trans_name = if b_trans { "t" } else { "n" };
|
let m_group = m_simd * m_splits;
|
||||||
let (iname, oname) = match name {
|
let n_group = n_simd * n_splits;
|
||||||
"sgemm" => ("float32", "float32"),
|
|
||||||
"hgemm" => ("float16", "float16"),
|
let a_block_length = m_group * k_simd;
|
||||||
"bgemm" => ("bfloat16", "bfloat16"),
|
let b_block_length = k_simd * n_group;
|
||||||
|
|
||||||
|
let mut block_elements = a_block_length + b_block_length;
|
||||||
|
if (m % 8 != 0) && (n % 8 != 0) {
|
||||||
|
let c_block_length = m_group * n_group;
|
||||||
|
block_elements = std::cmp::max(c_block_length, block_elements)
|
||||||
|
}
|
||||||
|
if fused_bias {
|
||||||
|
if d_trans {
|
||||||
|
block_elements = std::cmp::max(block_elements, m_group);
|
||||||
|
} else {
|
||||||
|
block_elements = std::cmp::max(block_elements, n_group);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let bytes = match name {
|
||||||
|
"sgemm" => 4,
|
||||||
|
"hgemm" => 2,
|
||||||
other => {
|
other => {
|
||||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||||
"{other} is not a valid kernel for gemm"
|
"{other} is not a valid kernel for gemm"
|
||||||
)))
|
)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let mut bm = 32;
|
let block_bytes = block_elements * bytes;
|
||||||
let mut bn = 32;
|
|
||||||
let mut bk = 16;
|
|
||||||
let wm = 2;
|
|
||||||
let wn = 2;
|
|
||||||
if b * m * n >= 1 << 20 {
|
|
||||||
if !a_trans && b_trans {
|
|
||||||
bm = 64;
|
|
||||||
bn = if oname == "float32" { 64 } else { 32 };
|
|
||||||
bk = if oname == "float32" { 16 } else { 32 };
|
|
||||||
} else {
|
|
||||||
bm = 64;
|
|
||||||
bn = 64;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mnaligned = if m % bm == 0 && n % bn == 0 {
|
|
||||||
"taligned"
|
|
||||||
} else {
|
|
||||||
"naligned"
|
|
||||||
};
|
|
||||||
let kaligned = if k % bk == 0 { "taligned" } else { "naligned" };
|
|
||||||
// let bytes = match &name[..] {
|
|
||||||
// "sgemm" => 4,
|
|
||||||
// "hgemm" => 2,
|
|
||||||
// other => {
|
|
||||||
// return Err(MetalKernelError::LoadLibraryError(format!(
|
|
||||||
// "{other} is not a valid kernel for gemm"
|
|
||||||
// )));
|
|
||||||
// }
|
|
||||||
// };
|
|
||||||
let name = format!("steel_gemm_{a_trans_name}{b_trans_name}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mnaligned}_K_{kaligned}");
|
|
||||||
let name = string_to_static_str(name);
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Gemm, name)?;
|
|
||||||
// let m_group = m_simd * m_splits;
|
|
||||||
// let n_group = n_simd * n_splits;
|
|
||||||
|
|
||||||
// let a_block_length = m_group * k_simd;
|
|
||||||
// let b_block_length = k_simd * n_group;
|
|
||||||
|
|
||||||
// let mut block_elements = a_block_length + b_block_length;
|
|
||||||
// if (m % 8 != 0) && (n % 8 != 0) {
|
|
||||||
// let c_block_length = m_group * n_group;
|
|
||||||
// block_elements = std::cmp::max(c_block_length, block_elements)
|
|
||||||
// }
|
|
||||||
// if fused_bias {
|
|
||||||
// if d_trans {
|
|
||||||
// block_elements = std::cmp::max(block_elements, m_group);
|
|
||||||
// } else {
|
|
||||||
// block_elements = std::cmp::max(block_elements, n_group);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// let block_bytes = block_elements * bytes;
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
// encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||||
|
|
||||||
let batch_stride_a: i32 = if lhs_stride.len() > 2 {
|
|
||||||
lhs_stride[lhs_stride.len() - 3] as i32
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
let batch_stride_b: i32 = if rhs_stride.len() > 2 {
|
|
||||||
rhs_stride[rhs_stride.len() - 3] as i32
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
let batch_stride_c = (m * n) as i32;
|
|
||||||
|
|
||||||
let swizzle_log = 0;
|
|
||||||
let tiles_n = ((n + bn - 1) / bn) as c_int;
|
|
||||||
let tiles_m = ((m + bm - 1) / bm) as c_int;
|
|
||||||
|
|
||||||
let params = GEMMParams {
|
|
||||||
m: m as c_int,
|
|
||||||
n: n as c_int,
|
|
||||||
k: k as c_int,
|
|
||||||
lda,
|
|
||||||
ldb,
|
|
||||||
ldc: n as c_int,
|
|
||||||
tiles_m,
|
|
||||||
tiles_n,
|
|
||||||
batch_stride_a,
|
|
||||||
batch_stride_b,
|
|
||||||
batch_stride_c,
|
|
||||||
swizzle_log,
|
|
||||||
gemm_k_iterations_aligned: (k / bk) as c_int,
|
|
||||||
};
|
|
||||||
let params_buffer = device.new_buffer_with_data(
|
|
||||||
¶ms as *const GEMMParams as *const c_void,
|
|
||||||
core::mem::size_of::<GEMMParams>() as u64,
|
|
||||||
MTLResourceOptions::StorageModeShared,
|
|
||||||
);
|
|
||||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||||
encoder.set_buffer(2, Some(output), 0);
|
encoder.set_buffer(2, Some(output), 0);
|
||||||
encoder.set_buffer(3, Some(¶ms_buffer), 0);
|
|
||||||
// TODO Tensor D
|
// TODO Tensor D
|
||||||
|
|
||||||
let grid_z = b;
|
let grid_z = b;
|
||||||
// if batched {
|
if batched {
|
||||||
// let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
||||||
// let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
||||||
// let byte_stride_c = m * n * bytes as usize;
|
let byte_stride_c = m * n * bytes as usize;
|
||||||
// // TODO byte_stride_d
|
// TODO byte_stride_d
|
||||||
// let byte_stride_d = 0;
|
let byte_stride_d = 0;
|
||||||
|
|
||||||
// let buffer: Vec<u64> = vec![
|
let buffer: Vec<u64> = vec![
|
||||||
// byte_stride_a as _,
|
byte_stride_a as _,
|
||||||
// byte_stride_b as _,
|
byte_stride_b as _,
|
||||||
// byte_stride_c as _,
|
byte_stride_c as _,
|
||||||
// byte_stride_d as _,
|
byte_stride_d as _,
|
||||||
// ];
|
];
|
||||||
// // encoder.set_bytes(
|
encoder.set_bytes(
|
||||||
// // 10,
|
10,
|
||||||
// // (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
||||||
// // buffer.as_ptr() as *const NSUInteger as *const c_void,
|
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
||||||
// // );
|
);
|
||||||
// }
|
}
|
||||||
let tile = 1 << swizzle_log;
|
|
||||||
let tm = (tiles_m + tile - 1) / tile;
|
|
||||||
let tn = tiles_n * tile;
|
|
||||||
|
|
||||||
let grid_size = MTLSize {
|
let grid_size = MTLSize {
|
||||||
width: tn as u64,
|
width: divide(n, n_group.into()),
|
||||||
height: tm as u64,
|
height: divide(m, m_group.into()),
|
||||||
depth: grid_z as NSUInteger,
|
depth: grid_z as NSUInteger,
|
||||||
};
|
};
|
||||||
let group_size = MTLSize {
|
let group_size = MTLSize {
|
||||||
width: 32,
|
width: 32 * (m_splits as u64) * (n_splits as u64),
|
||||||
height: wn,
|
height: 1,
|
||||||
depth: wm,
|
depth: 1,
|
||||||
};
|
};
|
||||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
|
@ -231,6 +231,25 @@ fn gelu_f32() {
|
|||||||
assert_eq!(approx(results, 3), expected);
|
assert_eq!(approx(results, 3), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn silu_f16() {
|
||||||
|
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
|
||||||
|
.iter()
|
||||||
|
.map(|v| f16::from_f32(*v))
|
||||||
|
.collect();
|
||||||
|
let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
|
||||||
|
let results = run(&v, unary::contiguous::silu::HALF);
|
||||||
|
assert_eq!(approx_f16(results, 2), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn silu_f32() {
|
||||||
|
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
|
||||||
|
let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
|
||||||
|
let results = run(&v, unary::contiguous::silu::FLOAT);
|
||||||
|
assert_eq!(approx(results, 3), expected);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn binary_add_f32() {
|
fn binary_add_f32() {
|
||||||
let left = vec![1.0f32, 2.0, 3.0];
|
let left = vec![1.0f32, 2.0, 3.0];
|
||||||
|
@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){
|
|||||||
}
|
}
|
||||||
return in;
|
return in;
|
||||||
}
|
}
|
||||||
|
template <typename T> METAL_FUNC T silu(T in){
|
||||||
|
return in / (static_cast<T>(1) + exp(-in));
|
||||||
|
}
|
||||||
|
|
||||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
@ -108,6 +111,7 @@ UNARY_OP(neg)
|
|||||||
UNARY_OP(exp)
|
UNARY_OP(exp)
|
||||||
UNARY_OP(log)
|
UNARY_OP(log)
|
||||||
UNARY_OP(gelu)
|
UNARY_OP(gelu)
|
||||||
|
UNARY_OP(silu)
|
||||||
UNARY_OP(abs)
|
UNARY_OP(abs)
|
||||||
UNARY_OP(ceil)
|
UNARY_OP(ceil)
|
||||||
UNARY_OP(floor)
|
UNARY_OP(floor)
|
||||||
@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg)
|
|||||||
BFLOAT_UNARY_OP(exp)
|
BFLOAT_UNARY_OP(exp)
|
||||||
BFLOAT_UNARY_OP(log)
|
BFLOAT_UNARY_OP(log)
|
||||||
BFLOAT_UNARY_OP(gelu)
|
BFLOAT_UNARY_OP(gelu)
|
||||||
|
BFLOAT_UNARY_OP(silu)
|
||||||
BFLOAT_UNARY_OP(abs)
|
BFLOAT_UNARY_OP(abs)
|
||||||
BFLOAT_UNARY_OP(ceil)
|
BFLOAT_UNARY_OP(ceil)
|
||||||
BFLOAT_UNARY_OP(floor)
|
BFLOAT_UNARY_OP(floor)
|
||||||
|
@ -30,7 +30,7 @@ impl super::Module for Activation {
|
|||||||
Self::Relu => xs.relu(),
|
Self::Relu => xs.relu(),
|
||||||
Self::Relu2 => xs.relu()?.sqr(),
|
Self::Relu2 => xs.relu()?.sqr(),
|
||||||
Self::Relu6 => xs.clamp(0f32, 6f32),
|
Self::Relu6 => xs.clamp(0f32, 6f32),
|
||||||
Self::Silu => crate::ops::silu(xs),
|
Self::Silu => xs.silu(),
|
||||||
Self::Sigmoid => crate::ops::sigmoid(xs),
|
Self::Sigmoid => crate::ops::sigmoid(xs),
|
||||||
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
|
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
|
||||||
Self::Swiglu => crate::ops::swiglu(xs),
|
Self::Swiglu => crate::ops::swiglu(xs),
|
||||||
|
@ -35,13 +35,12 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn silu(xs: &Tensor) -> Result<Tensor> {
|
pub fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||||
// TODO: Should we have a specialized op for this?
|
xs.silu()
|
||||||
xs / (xs.neg()?.exp()? + 1.0)?
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
||||||
let xs = xs.chunk(2, candle::D::Minus1)?;
|
let xs = xs.chunk(2, candle::D::Minus1)?;
|
||||||
crate::ops::silu(&xs[0])? * &xs[1]
|
&xs[0].silu()? * &xs[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||||
|
@ -2,10 +2,16 @@
|
|||||||
//!
|
//!
|
||||||
//! See "A ConvNet for the 2020s" Liu et al. 2022
|
//! See "A ConvNet for the 2020s" Liu et al. 2022
|
||||||
//! <https://arxiv.org/abs/2201.03545>
|
//! <https://arxiv.org/abs/2201.03545>
|
||||||
|
//! and
|
||||||
|
//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
|
||||||
|
//! <https://arxiv.org/abs/2301.00808>
|
||||||
|
|
||||||
//! Original code: https://github.com/facebookresearch/ConvNeXt/
|
//! Original code:
|
||||||
|
//! https://github.com/facebookresearch/ConvNeXt/
|
||||||
|
//! https://github.com/facebookresearch/ConvNeXt-V2/
|
||||||
//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
|
//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
|
||||||
|
|
||||||
|
use candle::shape::ShapeWithOneHole;
|
||||||
use candle::{Result, D};
|
use candle::{Result, D};
|
||||||
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
|
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
|
||||||
|
|
||||||
@ -13,31 +19,71 @@ use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
|
|||||||
pub struct Config {
|
pub struct Config {
|
||||||
blocks: [usize; 4],
|
blocks: [usize; 4],
|
||||||
channels: [usize; 4],
|
channels: [usize; 4],
|
||||||
|
use_conv_mlp: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
pub fn atto() -> Self {
|
||||||
|
Self {
|
||||||
|
blocks: [2, 2, 6, 2],
|
||||||
|
channels: [40, 80, 160, 320],
|
||||||
|
use_conv_mlp: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn femto() -> Self {
|
||||||
|
Self {
|
||||||
|
blocks: [2, 2, 6, 2],
|
||||||
|
channels: [48, 96, 192, 384],
|
||||||
|
use_conv_mlp: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pico() -> Self {
|
||||||
|
Self {
|
||||||
|
blocks: [2, 2, 6, 2],
|
||||||
|
channels: [64, 128, 256, 512],
|
||||||
|
use_conv_mlp: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn nano() -> Self {
|
||||||
|
Self {
|
||||||
|
blocks: [2, 2, 8, 2],
|
||||||
|
channels: [80, 160, 320, 640],
|
||||||
|
use_conv_mlp: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn tiny() -> Self {
|
pub fn tiny() -> Self {
|
||||||
Self {
|
Self {
|
||||||
blocks: [3, 3, 9, 3],
|
blocks: [3, 3, 9, 3],
|
||||||
channels: [96, 192, 384, 768],
|
channels: [96, 192, 384, 768],
|
||||||
|
use_conv_mlp: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn small() -> Self {
|
pub fn small() -> Self {
|
||||||
Self {
|
Self {
|
||||||
blocks: [3, 3, 27, 3],
|
blocks: [3, 3, 27, 3],
|
||||||
channels: [96, 192, 384, 768],
|
channels: [96, 192, 384, 768],
|
||||||
|
use_conv_mlp: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn base() -> Self {
|
pub fn base() -> Self {
|
||||||
Self {
|
Self {
|
||||||
blocks: [3, 3, 27, 3],
|
blocks: [3, 3, 27, 3],
|
||||||
channels: [128, 256, 512, 1024],
|
channels: [128, 256, 512, 1024],
|
||||||
|
use_conv_mlp: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn large() -> Self {
|
pub fn large() -> Self {
|
||||||
Self {
|
Self {
|
||||||
blocks: [3, 3, 27, 3],
|
blocks: [3, 3, 27, 3],
|
||||||
channels: [192, 384, 768, 1536],
|
channels: [192, 384, 768, 1536],
|
||||||
|
use_conv_mlp: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,8 +91,68 @@ impl Config {
|
|||||||
Self {
|
Self {
|
||||||
blocks: [3, 3, 27, 3],
|
blocks: [3, 3, 27, 3],
|
||||||
channels: [256, 512, 1024, 2048],
|
channels: [256, 512, 1024, 2048],
|
||||||
|
use_conv_mlp: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn huge() -> Self {
|
||||||
|
Self {
|
||||||
|
blocks: [3, 3, 27, 3],
|
||||||
|
channels: [352, 704, 1408, 2816],
|
||||||
|
use_conv_mlp: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Layer norm for data in channels-last format.
|
||||||
|
fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let norm = layer_norm(dim, 1e-6, vb)?;
|
||||||
|
|
||||||
|
Ok(Func::new(move |xs| xs.apply(&norm)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Layer norm for data in channels-first format.
|
||||||
|
fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let norm = layer_norm(dim, 1e-6, vb)?;
|
||||||
|
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
let xs = xs
|
||||||
|
.permute((0, 2, 3, 1))?
|
||||||
|
.apply(&norm)?
|
||||||
|
.permute((0, 3, 1, 2))?;
|
||||||
|
Ok(xs)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Global response normalization layer
|
||||||
|
// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
|
||||||
|
fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let (shape, spatial_dim, channel_dim) = if channels_last {
|
||||||
|
((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
|
||||||
|
} else {
|
||||||
|
((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
|
||||||
|
};
|
||||||
|
|
||||||
|
let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
|
||||||
|
let beta = vb.get(dim, "bias")?.reshape(&shape)?;
|
||||||
|
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
let residual = xs;
|
||||||
|
let gx = xs
|
||||||
|
.sqr()?
|
||||||
|
.sum_keepdim(spatial_dim)?
|
||||||
|
.mean_keepdim(spatial_dim)?
|
||||||
|
.sqrt()?;
|
||||||
|
|
||||||
|
let gxmean = gx.mean_keepdim(channel_dim)?;
|
||||||
|
let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
|
||||||
|
let xs = xs
|
||||||
|
.broadcast_mul(&nx)?
|
||||||
|
.broadcast_mul(&gamma)?
|
||||||
|
.broadcast_add(&beta)?;
|
||||||
|
|
||||||
|
xs + residual
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initial downsampling via a patchify layer.
|
// Initial downsampling via a patchify layer.
|
||||||
@ -56,16 +162,9 @@ fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
|
let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
|
||||||
let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
|
let norm = layer_norm_cf(out_channels, vb.pp(1))?;
|
||||||
Ok(Func::new(move |xs| {
|
|
||||||
// The layer norm works with channels-last format.
|
Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
|
||||||
let xs = xs
|
|
||||||
.apply(&patchify)?
|
|
||||||
.permute((0, 2, 3, 1))?
|
|
||||||
.apply(&norm)?
|
|
||||||
.permute((0, 3, 1, 2))?;
|
|
||||||
Ok(xs)
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Downsampling applied after the stages.
|
// Downsampling applied after the stages.
|
||||||
@ -74,31 +173,49 @@ fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
|||||||
stride: 2,
|
stride: 2,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
|
let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
|
||||||
let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
|
let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
|
||||||
Ok(Func::new(move |xs| {
|
|
||||||
let xs = xs
|
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
|
||||||
.permute((0, 2, 3, 1))?
|
|
||||||
.apply(&norm)?
|
|
||||||
.permute((0, 3, 1, 2))?
|
|
||||||
.apply(&conv)?;
|
|
||||||
Ok(xs)
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MLP equivalent of pointwise convolutions.
|
// MLP block from the original paper with optional GRN layer (v2 models).
|
||||||
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
|
let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
|
||||||
let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
|
let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
|
||||||
|
let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
|
||||||
|
|
||||||
Ok(Func::new(move |xs| {
|
Ok(Func::new(move |xs| {
|
||||||
let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
|
let mut xs = xs.apply(&fc1)?.gelu_erf()?;
|
||||||
|
if let Ok(g) = &grn {
|
||||||
|
xs = xs.apply(g)?;
|
||||||
|
}
|
||||||
|
xs = xs.apply(&fc2)?;
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
// A block consisting of a depthwise convolution, a MLP and layer scaling.
|
// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
|
||||||
fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
|
let conv2d_cfg = Conv2dConfig {
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||||
|
let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||||
|
|
||||||
|
let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
|
||||||
|
Ok(Func::new(move |xs| {
|
||||||
|
let mut xs = xs.apply(&fc1)?.gelu_erf()?;
|
||||||
|
if let Ok(g) = &grn {
|
||||||
|
xs = xs.apply(g)?;
|
||||||
|
}
|
||||||
|
xs = xs.apply(&fc2)?;
|
||||||
|
Ok(xs)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
|
||||||
|
fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
let conv2d_cfg = Conv2dConfig {
|
let conv2d_cfg = Conv2dConfig {
|
||||||
groups: dim,
|
groups: dim,
|
||||||
padding: 3,
|
padding: 3,
|
||||||
@ -106,20 +223,36 @@ fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
|
let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
|
||||||
|
let gamma = vb.get(dim, "gamma");
|
||||||
|
|
||||||
let gamma = vb.get(dim, "gamma")?;
|
let (mlp, norm) = if use_conv_mlp {
|
||||||
let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
|
(
|
||||||
let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
|
convnext_conv_mlp(dim, vb.pp("mlp"))?,
|
||||||
|
layer_norm_cf(dim, vb.pp("norm"))?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
convnext_mlp(dim, vb.pp("mlp"))?,
|
||||||
|
layer_norm_cl(dim, vb.pp("norm"))?,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Func::new(move |xs| {
|
Ok(Func::new(move |xs| {
|
||||||
let residual = xs;
|
let residual = xs;
|
||||||
let xs = xs
|
let mut xs = xs.apply(&conv_dw)?;
|
||||||
.apply(&conv_dw)?
|
|
||||||
.permute((0, 2, 3, 1))?
|
xs = if use_conv_mlp {
|
||||||
|
xs.apply(&norm)?.apply(&mlp)?
|
||||||
|
} else {
|
||||||
|
xs.permute((0, 2, 3, 1))?
|
||||||
.apply(&norm)?
|
.apply(&norm)?
|
||||||
.apply(&mlp)?
|
.apply(&mlp)?
|
||||||
.broadcast_mul(&gamma)?
|
.permute((0, 3, 1, 2))?
|
||||||
.permute((0, 3, 1, 2))?;
|
};
|
||||||
|
|
||||||
|
if let Ok(g) = &gamma {
|
||||||
|
xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
|
||||||
|
};
|
||||||
|
|
||||||
xs + residual
|
xs + residual
|
||||||
}))
|
}))
|
||||||
@ -137,7 +270,11 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
|
|||||||
}
|
}
|
||||||
|
|
||||||
for block_idx in 0..nblocks {
|
for block_idx in 0..nblocks {
|
||||||
blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
|
blocks.push(convnext_block(
|
||||||
|
dim,
|
||||||
|
cfg.use_conv_mlp,
|
||||||
|
vb.pp(format!("blocks.{block_idx}")),
|
||||||
|
)?);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Func::new(move |xs| {
|
Ok(Func::new(move |xs| {
|
||||||
@ -149,8 +286,9 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Classification head.
|
||||||
fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||||
let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
|
let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
|
||||||
let linear = linear(outputs, nclasses, vb.pp("fc"))?;
|
let linear = linear(outputs, nclasses, vb.pp("fc"))?;
|
||||||
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
|
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
|
||||||
}
|
}
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
use super::with_tracing::{linear_no_bias as linear, Linear};
|
use super::with_tracing::{linear_no_bias as linear, Linear};
|
||||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||||
use serde::Deserialize;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
pub const MAX_SEQ_LEN: usize = 4096;
|
pub const MAX_SEQ_LEN: usize = 4096;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
pub struct LlamaConfig {
|
pub struct LlamaConfig {
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
|
@ -34,6 +34,7 @@ pub mod quantized_t5;
|
|||||||
pub mod qwen2;
|
pub mod qwen2;
|
||||||
pub mod repvgg;
|
pub mod repvgg;
|
||||||
pub mod resnet;
|
pub mod resnet;
|
||||||
|
pub mod rwkv_v5;
|
||||||
pub mod segment_anything;
|
pub mod segment_anything;
|
||||||
pub mod stable_diffusion;
|
pub mod stable_diffusion;
|
||||||
pub mod stable_lm;
|
pub mod stable_lm;
|
||||||
@ -41,6 +42,7 @@ pub mod t5;
|
|||||||
pub mod trocr;
|
pub mod trocr;
|
||||||
pub mod vgg;
|
pub mod vgg;
|
||||||
pub mod vit;
|
pub mod vit;
|
||||||
|
pub mod vocos;
|
||||||
pub mod whisper;
|
pub mod whisper;
|
||||||
pub mod with_tracing;
|
pub mod with_tracing;
|
||||||
pub mod wuerstchen;
|
pub mod wuerstchen;
|
||||||
|
409
candle-transformers/src/models/rwkv_v5.rs
Normal file
409
candle-transformers/src/models/rwkv_v5.rs
Normal file
@ -0,0 +1,409 @@
|
|||||||
|
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
|
||||||
|
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||||
|
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
|
/// Serde fallback used for `Config::num_attention_heads` when the key is
/// missing from the JSON configuration.
fn default_num_attention_heads() -> usize {
    const DEFAULT_NUM_ATTENTION_HEADS: usize = 64;
    DEFAULT_NUM_ATTENTION_HEADS
}
|
||||||
|
|
||||||
|
// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
|
||||||
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub attention_hidden_size: usize,
|
||||||
|
#[serde(default = "default_num_attention_heads")]
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub head_size: usize,
|
||||||
|
pub intermediate_size: Option<usize>,
|
||||||
|
pub layer_norm_epsilon: f64,
|
||||||
|
pub rescale_every: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct StatePerLayer {
|
||||||
|
extract_key_value: Tensor,
|
||||||
|
linear_attention: Tensor,
|
||||||
|
feed_forward: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct State {
|
||||||
|
per_layer: Vec<StatePerLayer>,
|
||||||
|
pos: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl State {
|
||||||
|
pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
// Certainly a weird convention but taken from modeling_rwkv5.py
|
||||||
|
let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
|
||||||
|
for _layer_idx in 0..cfg.num_hidden_layers {
|
||||||
|
let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
|
||||||
|
let linear_attention = Tensor::zeros(
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
num_attention_heads,
|
||||||
|
cfg.hidden_size / num_attention_heads,
|
||||||
|
cfg.hidden_size / num_attention_heads,
|
||||||
|
),
|
||||||
|
DType::F32,
|
||||||
|
dev,
|
||||||
|
)?;
|
||||||
|
let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
|
||||||
|
per_layer.push(StatePerLayer {
|
||||||
|
extract_key_value,
|
||||||
|
linear_attention,
|
||||||
|
feed_forward,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(Self { per_layer, pos: 0 })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct SelfAttention {
|
||||||
|
key: Linear,
|
||||||
|
receptance: Linear,
|
||||||
|
value: Linear,
|
||||||
|
gate: Linear,
|
||||||
|
output: Linear,
|
||||||
|
ln_x: candle_nn::GroupNorm,
|
||||||
|
time_mix_key: Tensor,
|
||||||
|
time_mix_value: Tensor,
|
||||||
|
time_mix_receptance: Tensor,
|
||||||
|
time_decay: Tensor,
|
||||||
|
time_faaaa: Tensor,
|
||||||
|
time_mix_gate: Tensor,
|
||||||
|
layer_id: usize,
|
||||||
|
n_attn_heads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SelfAttention {
|
||||||
|
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_size = cfg.hidden_size;
|
||||||
|
let attn_hidden_size = cfg.attention_hidden_size;
|
||||||
|
let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
|
||||||
|
let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
|
||||||
|
let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
|
||||||
|
let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
|
||||||
|
let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
|
||||||
|
let ln_x = candle_nn::group_norm(
|
||||||
|
hidden_size / cfg.head_size,
|
||||||
|
hidden_size,
|
||||||
|
1e-5,
|
||||||
|
vb.pp("ln_x"),
|
||||||
|
)?;
|
||||||
|
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
|
||||||
|
let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
|
||||||
|
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
|
||||||
|
let n_attn_heads = cfg.hidden_size / cfg.head_size;
|
||||||
|
let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
|
||||||
|
let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
|
||||||
|
let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
|
||||||
|
Ok(Self {
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
receptance,
|
||||||
|
gate,
|
||||||
|
output,
|
||||||
|
ln_x,
|
||||||
|
time_mix_key,
|
||||||
|
time_mix_value,
|
||||||
|
time_mix_receptance,
|
||||||
|
time_decay,
|
||||||
|
time_faaaa,
|
||||||
|
time_mix_gate,
|
||||||
|
layer_id,
|
||||||
|
n_attn_heads,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let h = self.time_decay.dim(0)?;
|
||||||
|
let (b, t, s) = xs.dims3()?;
|
||||||
|
let s = s / h;
|
||||||
|
let (receptance, key, value, gate) = {
|
||||||
|
// exctract key-value
|
||||||
|
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
|
||||||
|
let shifted = if shifted.rank() == 2 {
|
||||||
|
shifted.unsqueeze(1)?
|
||||||
|
} else {
|
||||||
|
shifted
|
||||||
|
};
|
||||||
|
let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
|
||||||
|
let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
|
||||||
|
let receptance = ((xs * &self.time_mix_receptance)?
|
||||||
|
+ &shifted * (1.0 - &self.time_mix_receptance)?)?;
|
||||||
|
let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
|
||||||
|
|
||||||
|
let key = self.key.forward(&key)?;
|
||||||
|
let value = self.value.forward(&value)?;
|
||||||
|
let receptance = self.receptance.forward(&receptance)?;
|
||||||
|
let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
|
||||||
|
state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
|
||||||
|
(receptance, key, value, gate)
|
||||||
|
};
|
||||||
|
// linear attention
|
||||||
|
let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
|
||||||
|
let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
|
||||||
|
let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
|
||||||
|
let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let time_decay = self
|
||||||
|
.time_decay
|
||||||
|
.exp()?
|
||||||
|
.neg()?
|
||||||
|
.exp()?
|
||||||
|
.reshape(((), 1, 1))?
|
||||||
|
.reshape((self.n_attn_heads, (), 1))?;
|
||||||
|
let time_faaaa =
|
||||||
|
self.time_faaaa
|
||||||
|
.reshape(((), 1, 1))?
|
||||||
|
.reshape((self.n_attn_heads, (), 1))?;
|
||||||
|
|
||||||
|
let mut out: Vec<Tensor> = Vec::with_capacity(t);
|
||||||
|
for t_ in 0..t {
|
||||||
|
//
|
||||||
|
let rt = receptance.i((.., .., t_..t_ + 1))?;
|
||||||
|
let kt = key.i((.., .., .., t_..t_ + 1))?;
|
||||||
|
let vt = value.i((.., .., t_..t_ + 1))?;
|
||||||
|
let at = kt.matmul(&vt)?;
|
||||||
|
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
|
||||||
|
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
|
||||||
|
state_ = (&at + time_decay.broadcast_mul(&state_))?;
|
||||||
|
out.push(out_)
|
||||||
|
}
|
||||||
|
let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
|
||||||
|
let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
|
||||||
|
let out = (out * gate)?.apply(&self.output)?;
|
||||||
|
state.per_layer[self.layer_id].linear_attention = state_;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct FeedForward {
|
||||||
|
time_mix_key: Tensor,
|
||||||
|
time_mix_receptance: Tensor,
|
||||||
|
key: Linear,
|
||||||
|
receptance: Linear,
|
||||||
|
value: Linear,
|
||||||
|
layer_id: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FeedForward {
|
||||||
|
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let int_size = cfg
|
||||||
|
.intermediate_size
|
||||||
|
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
|
||||||
|
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
|
||||||
|
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
|
||||||
|
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
|
||||||
|
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
|
||||||
|
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
|
||||||
|
Ok(Self {
|
||||||
|
key,
|
||||||
|
receptance,
|
||||||
|
value,
|
||||||
|
time_mix_key,
|
||||||
|
time_mix_receptance,
|
||||||
|
layer_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let shifted = &state.per_layer[self.layer_id].feed_forward;
|
||||||
|
let key = (xs.broadcast_mul(&self.time_mix_key)?
|
||||||
|
+ shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
|
||||||
|
let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
|
||||||
|
+ shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
|
||||||
|
let key = key.apply(&self.key)?.relu()?.sqr()?;
|
||||||
|
let value = key.apply(&self.value)?;
|
||||||
|
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
|
||||||
|
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
|
||||||
|
let xs = (receptance * value)?;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Block {
|
||||||
|
pre_ln: Option<LayerNorm>,
|
||||||
|
ln1: LayerNorm,
|
||||||
|
ln2: LayerNorm,
|
||||||
|
attention: SelfAttention,
|
||||||
|
feed_forward: FeedForward,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Block {
|
||||||
|
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
|
||||||
|
let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
|
||||||
|
let pre_ln = if layer_id == 0 {
|
||||||
|
let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
|
||||||
|
Some(ln)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
|
||||||
|
let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
|
||||||
|
Ok(Self {
|
||||||
|
pre_ln,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
attention,
|
||||||
|
feed_forward,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let xs = match self.pre_ln.as_ref() {
|
||||||
|
None => xs.clone(),
|
||||||
|
Some(pre_ln) => xs.apply(pre_ln)?,
|
||||||
|
};
|
||||||
|
let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
|
||||||
|
let xs = (xs + attention)?;
|
||||||
|
let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
|
||||||
|
let xs = (xs + feed_forward)?;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embeddings: Embedding,
|
||||||
|
blocks: Vec<Block>,
|
||||||
|
ln_out: LayerNorm,
|
||||||
|
head: Linear,
|
||||||
|
rescale_every: usize,
|
||||||
|
layers_are_rescaled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_m = vb.pp("rwkv");
|
||||||
|
let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
|
||||||
|
let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_b = vb_m.pp("blocks");
|
||||||
|
for block_index in 0..cfg.num_hidden_layers {
|
||||||
|
let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
|
||||||
|
blocks.push(block)
|
||||||
|
}
|
||||||
|
let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
|
||||||
|
let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
blocks,
|
||||||
|
ln_out,
|
||||||
|
head,
|
||||||
|
rescale_every: cfg.rescale_every,
|
||||||
|
layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let (_b_size, _seq_len) = xs.dims2()?;
|
||||||
|
let mut xs = xs.apply(&self.embeddings)?;
|
||||||
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
|
xs = block.forward(&xs, state)?;
|
||||||
|
if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
|
||||||
|
xs = (xs / 2.)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
|
||||||
|
state.pos += 1;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Bytes = Vec<u8>;
|
||||||
|
|
||||||
|
// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14
|
||||||
|
pub struct Tokenizer {
|
||||||
|
table: Vec<Vec<Vec<Bytes>>>,
|
||||||
|
good: Vec<HashSet<u8>>,
|
||||||
|
idx2token: HashMap<u32, Vec<u8>>,
|
||||||
|
token2idx: HashMap<Vec<u8>, u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tokenizer {
|
||||||
|
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||||
|
let file = std::fs::File::open(p)?;
|
||||||
|
let token2idx: HashMap<String, u32> =
|
||||||
|
serde_json::from_reader(file).map_err(candle::Error::wrap)?;
|
||||||
|
let token2idx = token2idx
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, value)| (key.into_bytes(), value))
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
let idx2token = token2idx
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| (*value, key.to_vec()))
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
|
||||||
|
let max_idx = token2idx.values().copied().max().unwrap_or(0);
|
||||||
|
|
||||||
|
let mut table = vec![vec![vec![]; 256]; 256];
|
||||||
|
let mut good = vec![HashSet::new(); 256];
|
||||||
|
for idx in (0..(1 + max_idx)).rev() {
|
||||||
|
let s = match idx2token.get(&idx) {
|
||||||
|
None => continue,
|
||||||
|
Some(s) => s,
|
||||||
|
};
|
||||||
|
if s.len() >= 2 {
|
||||||
|
let (s0, s1) = (s[0], s[1]);
|
||||||
|
table[s0 as usize][s1 as usize].push(s.to_vec());
|
||||||
|
good[s0 as usize].insert(s1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
table,
|
||||||
|
good,
|
||||||
|
idx2token,
|
||||||
|
token2idx,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
|
||||||
|
let mut v = Vec::new();
|
||||||
|
for token_id in tokens.iter() {
|
||||||
|
if let Some(token) = self.idx2token.get(token_id) {
|
||||||
|
v.extend_from_slice(token.as_slice())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||||
|
let bytes = self.decode_bytes(tokens);
|
||||||
|
String::from_utf8(bytes).map_err(candle::Error::wrap)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
|
||||||
|
let mut tokens = Vec::new();
|
||||||
|
let mut i = 0;
|
||||||
|
while i < bytes.len() {
|
||||||
|
let mut s = vec![bytes[i]];
|
||||||
|
if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
|
||||||
|
let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
|
||||||
|
for table_elem in table.iter() {
|
||||||
|
if bytes[i..].starts_with(table_elem) {
|
||||||
|
s = table_elem.to_vec();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i += s.len();
|
||||||
|
let token = match self.token2idx.get(&s) {
|
||||||
|
None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
|
||||||
|
Some(token) => *token,
|
||||||
|
};
|
||||||
|
tokens.push(token)
|
||||||
|
}
|
||||||
|
Ok(tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
|
||||||
|
self.encode_bytes(str.as_bytes())
|
||||||
|
}
|
||||||
|
}
|
156
candle-transformers/src/models/vocos.rs
Normal file
156
candle-transformers/src/models/vocos.rs
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
use candle::{DType, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};
|
||||||
|
|
||||||
|
pub struct AdaLayerNorm {
|
||||||
|
eps: f64,
|
||||||
|
dim: usize,
|
||||||
|
scale: Embedding,
|
||||||
|
shift: Embedding,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
|
||||||
|
let x_dtype = x.dtype();
|
||||||
|
let internal_dtype = match x_dtype {
|
||||||
|
DType::F16 | DType::BF16 => DType::F32,
|
||||||
|
d => d,
|
||||||
|
};
|
||||||
|
let hidden_size = x.dim(D::Minus1)?;
|
||||||
|
let x = x.to_dtype(internal_dtype)?;
|
||||||
|
let x = {
|
||||||
|
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||||
|
x.broadcast_sub(&mean_x)?
|
||||||
|
};
|
||||||
|
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||||
|
let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
|
||||||
|
x_normed.to_dtype(x_dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AdaLayerNorm {
|
||||||
|
pub fn new(
|
||||||
|
num_embeddings: usize,
|
||||||
|
embedding_dim: usize,
|
||||||
|
eps: f64,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
|
||||||
|
let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
|
||||||
|
Ok(Self {
|
||||||
|
eps,
|
||||||
|
dim: embedding_dim,
|
||||||
|
scale,
|
||||||
|
shift,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
|
||||||
|
let scale = self.scale.forward(cond_embedding_id)?;
|
||||||
|
let shift = self.shift.forward(cond_embedding_id)?;
|
||||||
|
let xs = layer_norm(xs, self.eps)?;
|
||||||
|
xs * scale + shift
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ConvNeXtBlock {
|
||||||
|
dwconv: Conv1d,
|
||||||
|
pwconv1: Linear,
|
||||||
|
pwconv2: Linear,
|
||||||
|
gamma: Option<Tensor>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConvNeXtBlock {
|
||||||
|
pub fn new(
|
||||||
|
dim: usize,
|
||||||
|
intermediate_dim: usize,
|
||||||
|
layer_scale_init_value: f64,
|
||||||
|
adanorm_num_embeddings: Option<usize>,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let dwconv = {
|
||||||
|
let cfg = Conv1dConfig {
|
||||||
|
padding: 3,
|
||||||
|
groups: dim,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
|
||||||
|
};
|
||||||
|
let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
|
||||||
|
let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
|
||||||
|
let gamma = if layer_scale_init_value > 0. {
|
||||||
|
Some(vb.get(dim, "gamma")?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
dwconv,
|
||||||
|
pwconv1,
|
||||||
|
pwconv2,
|
||||||
|
gamma,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
|
||||||
|
// TODO: norm
|
||||||
|
let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
|
||||||
|
let xs = match self.gamma.as_ref() {
|
||||||
|
Some(gamma) => (gamma * xs)?,
|
||||||
|
None => xs,
|
||||||
|
};
|
||||||
|
xs.transpose(1, 2)? + residual
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct VocosBackbone {
|
||||||
|
embed: Conv1d,
|
||||||
|
convnext: Vec<ConvNeXtBlock>,
|
||||||
|
final_layer_norm: candle_nn::LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VocosBackbone {
|
||||||
|
pub fn new(
|
||||||
|
input_channels: usize,
|
||||||
|
dim: usize,
|
||||||
|
intermediate_dim: usize,
|
||||||
|
num_layers: dim,
|
||||||
|
layer_scale_init_value: f64,
|
||||||
|
adanorm_num_embeddings: Option<usize>,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let embed = {
|
||||||
|
let cfg = Conv1dConfig {
|
||||||
|
padding: 3,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
|
||||||
|
};
|
||||||
|
let mut convnext = Vec::with_capacity(num_layers);
|
||||||
|
let vb_c = vb.pp("convnext");
|
||||||
|
for i in 0..num_layers {
|
||||||
|
let block = ConvNeXtBlock::new(
|
||||||
|
dim,
|
||||||
|
intermediate_dim,
|
||||||
|
layer_scale_init_value,
|
||||||
|
adanorm_num_embeddings,
|
||||||
|
vb_c.pp(i),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
embed,
|
||||||
|
convnext,
|
||||||
|
final_layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let xs = xs.apply(&self.embed)?;
|
||||||
|
// TODO: norm
|
||||||
|
let mut xs = xs.transpose(1, 2)?;
|
||||||
|
for conv_block in self.convnext.iter() {
|
||||||
|
xs = conv_block.forward(&xs)?
|
||||||
|
}
|
||||||
|
xs.apply(&self.final_layer_norm)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user