mirror of https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00

Comparing commits bf16_metal...vocos — 7 commits:

- 3f3730b657
- 058a910d0e
- 26fe162ab5
- 121a71e01f
- 2d5f2a728d
- 68f7655895
- b60064780d
@@ -75,6 +75,9 @@ We also provide some command line based examples using state of the art models:
   experts 8x7b general LLM with better performance than a Llama 2 70B model with
   much faster inference.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
+- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
+- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
+  performance.
 - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
 - [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
   (English/Chinese) general LLMs with 6b and 34b parameters.
@@ -193,6 +196,8 @@ If you have an addition to this list, please submit a pull request.
 - Replit-code-v1.5-3B.
 - Bert.
 - Yi-6B and Yi-34B.
+- Qwen1.5.
+- RWKV.
 - Quantized LLMs.
 - Llama 7b, 13b, 70b, as well as the chat and code variants.
 - Mistral 7b, and 7b instruct.
@@ -210,7 +215,8 @@ If you have an addition to this list, please submit a pull request.
 - BLIP.
 - TrOCR.
 - Computer Vision Models.
-- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
+- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
+  ConvNeXTv2.
 - yolo-v3, yolo-v8.
 - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
@@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
     unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
 }
 
+#[inline]
+pub fn vs_exp_inplace(y: &mut [f32]) {
+    unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_exp_inplace(y: &mut [f64]) {
+    unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
 #[inline]
 pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
     }
 }
 
+#[inline]
+pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = -v
+    }
+    vs_exp_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = v / (1.0 + *y)
+    }
+}
+
+#[inline]
+pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = -v
+    }
+    vd_exp_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = v / (1.0 + *y)
+    }
+}
+
 macro_rules! binary_op {
     ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
         #[inline]
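The two-pass trick above (stage `-x` in the output buffer, exponentiate in place, then divide) is what lets SiLU reuse the vectorized exp routine. A minimal, dependency-free Rust sketch of the same computation, with a scalar `exp_inplace` standing in for the Accelerate-backed `vs_exp_inplace`:

```rust
// Scalar stand-in for the vectorized exp routine.
fn exp_inplace(ys: &mut [f32]) {
    for y in ys.iter_mut() {
        *y = y.exp();
    }
}

// SiLU via two passes over the output buffer: ys first holds -x,
// then exp(-x), and finally x / (1 + exp(-x)) = x * sigmoid(x).
fn silu_two_pass(vs: &[f32], ys: &mut [f32]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = -v;
    }
    exp_inplace(ys);
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = v / (1.0 + *y);
    }
}

fn main() {
    let xs = [-2.0f32, 0.0, 1.0, 3.0];
    let mut ys = [0.0f32; 4];
    silu_two_pass(&xs, &mut ys);
    for (x, y) in xs.iter().zip(ys.iter()) {
        // Compare against the direct x * sigmoid(x) formulation.
        let sigmoid = 1.0 / (1.0 + (-x).exp());
        assert!((y - x * sigmoid).abs() < 1e-6);
    }
    println!("{ys:?}"); // ≈ [-0.2384, 0.0, 0.7311, 2.8577]
}
```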
@@ -589,6 +589,13 @@ impl Tensor {
                 let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                 *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
             }
+            Op::Unary(arg, UnaryOp::Silu) => {
+                let sum_grad = grads.or_insert(arg)?;
+                // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+                let sigmoid_arg = (*node / arg)?;
+                let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
+                *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
+            }
             Op::Elu(arg, alpha) => {
                 // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
                 let sum_grad = grads.or_insert(arg)?;
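The backward pass recovers `sigmoid(x)` cheaply as `node / arg`, i.e. `silu(x) / x`, rather than recomputing the exponential. A plain-Rust sketch that checks the gradient formula in the comment above against a central finite difference:

```rust
fn silu(x: f64) -> f64 {
    x / (1.0 + (-x).exp())
}

// Analytic gradient: sigmoid(x) * (1 + x * (1 - sigmoid(x))).
fn silu_grad(x: f64) -> f64 {
    let s = 1.0 / (1.0 + (-x).exp()); // sigmoid(x), i.e. silu(x) / x
    s * (1.0 + x * (1.0 - s))
}

fn main() {
    let eps = 1e-6;
    for &x in &[-3.0, -0.5, 0.15, 1.0, 4.0] {
        let numeric = (silu(x + eps) - silu(x - eps)) / (2.0 * eps);
        assert!((numeric - silu_grad(x)).abs() < 1e-6);
    }
    println!("analytic silu gradient matches finite differences");
}
```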
@@ -679,6 +679,7 @@ impl BackendStorage for MetalStorage {
                 ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
                 ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
                 ("uerf", DType::F32) => contiguous::erf::FLOAT,
+                ("usilu", DType::F32) => contiguous::silu::FLOAT,
                 ("uabs", DType::F32) => contiguous::abs::FLOAT,
                 ("uceil", DType::F32) => contiguous::ceil::FLOAT,
                 ("ufloor", DType::F32) => contiguous::floor::FLOAT,
@@ -696,6 +697,7 @@ impl BackendStorage for MetalStorage {
                 ("ugelu", DType::F16) => contiguous::gelu::HALF,
                 ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
                 ("uerf", DType::F16) => contiguous::erf::HALF,
+                ("usilu", DType::F16) => contiguous::silu::HALF,
                 ("uabs", DType::F16) => contiguous::abs::HALF,
                 ("uceil", DType::F16) => contiguous::ceil::HALF,
                 ("ufloor", DType::F16) => contiguous::floor::HALF,
@@ -730,6 +732,7 @@ impl BackendStorage for MetalStorage {
                 ("ugelu", DType::F32) => strided::gelu::FLOAT,
                 ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
                 ("uerf", DType::F32) => strided::erf::FLOAT,
+                ("usilu", DType::F32) => strided::silu::FLOAT,
                 ("uabs", DType::F32) => strided::abs::FLOAT,
                 ("uceil", DType::F32) => strided::ceil::FLOAT,
                 ("ufloor", DType::F32) => strided::floor::FLOAT,
@@ -745,6 +748,7 @@ impl BackendStorage for MetalStorage {
                 ("ugelu", DType::F16) => strided::gelu::HALF,
                 ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
                 ("uerf", DType::F16) => strided::erf::HALF,
+                ("usilu", DType::F16) => strided::silu::HALF,
                 ("uabs", DType::F16) => strided::abs::HALF,
                 ("uceil", DType::F16) => strided::ceil::HALF,
                 ("ufloor", DType::F16) => strided::floor::HALF,
@@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
     unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
 }
 
+#[inline]
+pub fn vs_exp_inplace(y: &mut [f32]) {
+    unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
+}
+
+#[inline]
+pub fn vd_exp_inplace(y: &mut [f64]) {
+    unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
+}
+
 #[inline]
 pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
     }
 }
 
+#[inline]
+pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = -v
+    }
+    vs_exp_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = v / (1.0 + *y)
+    }
+}
+
+#[inline]
+pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = -v
+    }
+    vd_exp_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = v / (1.0 + *y)
+    }
+}
+
 macro_rules! binary_op {
     ($fn_name:ident, $ty:ty, $mkl_name:ident) => {
         #[inline]
@@ -61,6 +61,7 @@ pub enum UnaryOp {
     GeluErf,
     Erf,
     Relu,
+    Silu,
     Tanh,
     Floor,
     Ceil,
@@ -390,6 +391,7 @@ pub(crate) struct Gelu;
 pub(crate) struct GeluErf;
 pub(crate) struct Erf;
 pub(crate) struct Relu;
+pub(crate) struct Silu;
 pub(crate) struct Tanh;
 pub(crate) struct Floor;
 pub(crate) struct Ceil;
@@ -724,6 +726,77 @@ impl UnaryOpT for Erf {
     }
 }
 
+/// Silu operation
+impl UnaryOpT for Silu {
+    const NAME: &'static str = "silu";
+    const V: Self = Silu;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v / (bf16::ONE + (-v).exp())
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v / (f16::ONE + (-v).exp())
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v / (1.0 + (-v).exp())
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v / (1.0 + (-v).exp())
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
+    const KERNEL: &'static str = "usilu";
+
+    #[cfg(feature = "mkl")]
+    const F32_VEC: bool = true;
+
+    #[cfg(feature = "mkl")]
+    #[inline(always)]
+    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+        crate::mkl::vs_silu(xs, ys)
+    }
+
+    #[cfg(feature = "mkl")]
+    const F64_VEC: bool = true;
+
+    #[cfg(feature = "mkl")]
+    #[inline(always)]
+    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+        crate::mkl::vd_silu(xs, ys)
+    }
+
+    #[cfg(feature = "accelerate")]
+    const F32_VEC: bool = true;
+
+    #[cfg(feature = "accelerate")]
+    #[inline(always)]
+    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+        crate::accelerate::vs_silu(xs, ys)
+    }
+
+    #[cfg(feature = "accelerate")]
+    const F64_VEC: bool = true;
+
+    #[cfg(feature = "accelerate")]
+    #[inline(always)]
+    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+        crate::accelerate::vd_silu(xs, ys)
+    }
+}
+
 impl UnaryOpT for Abs {
     const NAME: &'static str = "abs";
     const KERNEL: &'static str = "uabs";
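The impl above fills in the same surface as the other unary ops: scalar paths per dtype, a kernel name for the GPU backends, and optional vectorized paths behind the `mkl`/`accelerate` features. A stripped-down sketch of that dispatch pattern (not candle's actual trait, just the shape of it):

```rust
// Hypothetical miniature of the unary-op trait: a scalar definition plus an
// overridable slice-level path. A feature-gated impl (mkl/accelerate) would
// replace `apply` with a call into a vectorized routine such as `vs_silu`.
trait UnaryScalar {
    fn scalar(v: f32) -> f32;
    fn apply(xs: &[f32], ys: &mut [f32]) {
        for (&x, y) in xs.iter().zip(ys.iter_mut()) {
            *y = Self::scalar(x);
        }
    }
}

struct SiluSketch;

impl UnaryScalar for SiluSketch {
    fn scalar(v: f32) -> f32 {
        v / (1.0 + (-v).exp())
    }
}

fn main() {
    let xs = [1.0f32, -2.0, 0.5];
    let mut ys = [0.0f32; 3];
    SiluSketch::apply(&xs, &mut ys);
    println!("{ys:?}"); // ≈ [0.7311, -0.2384, 0.3112]
}
```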
@@ -508,6 +508,7 @@ impl Tensor {
     unary_op!(gelu_erf, GeluErf);
     unary_op!(erf, Erf);
     unary_op!(relu, Relu);
+    unary_op!(silu, Silu);
     unary_op!(ceil, Ceil);
     unary_op!(floor, Floor);
     unary_op!(round, Round);
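With `unary_op!(silu, Silu)` in place, `Tensor::silu` is available like the other unary ops. A usage sketch, assuming a Cargo dependency aliased as `candle` (package `candle-core`), the same aliasing the examples in this repo use:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let x = Tensor::new(&[-1.0f32, 0.0, 2.0], &Device::Cpu)?;
    let y = x.silu()?; // elementwise x * sigmoid(x)
    println!("{y}");
    Ok(())
}
```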
@@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
         [0.7358, 2.0000, 0.2707, 1.0000]
     );
 
+    // testing compared to pytorch nn.SiLU()
+    let y = x.silu()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.8577, 0.7311, 3.9281, 0.0806]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [1.0881, 0.9277, 1.0527, 0.5747],
+    );
+
     // manually checked: see comments
     let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
     let y = x.interpolate2d(6, 6)?.reshape(36)?;
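The expected values are easy to reproduce by hand. The test input appears to be `[3., 1., 4., 0.15]` given the surrounding assertions (an assumption; only the expected outputs are visible in this hunk):

```rust
fn main() {
    let xs = [3.0f64, 1.0, 4.0, 0.15];
    // Forward: rounds to [2.8577, 0.7311, 3.9281, 0.0806].
    let ys: Vec<f64> = xs.iter().map(|x| x / (1.0 + (-x).exp())).collect();
    println!("{ys:?}");
    // Gradient: rounds to [1.0881, 0.9277, 1.0527, 0.5747].
    let grads: Vec<f64> = xs
        .iter()
        .map(|x| {
            let s = 1.0 / (1.0 + (-x).exp());
            s * (1.0 + x * (1.0 - s))
        })
        .collect();
    println!("{grads:?}");
}
```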
@@ -120,6 +120,13 @@ fn unary_op(device: &Device) -> Result<()> {
             [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
         ]
     );
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.silu()?, 4)?,
+        [
+            [-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
+            [2.53, -0.2553, -0.1205, 1.5447, 2.6395]
+        ]
+    );
     assert_eq!(
         test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
         [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
@@ -1,6 +1,7 @@
 # candle-convnext
 
-[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
+[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
+[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
 
 This candle implementation uses a pre-trained ConvNeXt network for inference. The
 classification head has been trained on the ImageNet dataset and returns the
@@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Which {
+    Atto,
+    Femto,
+    Pico,
+    Nano,
     Tiny,
     Small,
     Base,
     Large,
+    AttoV2,
+    FemtoV2,
+    PicoV2,
+    NanoV2,
+    TinyV2,
+    BaseV2,
+    LargeV2,
     XLarge,
+    Huge,
 }
 
 impl Which {
     fn model_filename(&self) -> String {
         let name = match self {
-            Self::Tiny => "tiny",
-            Self::Small => "small",
-            Self::Base => "base",
-            Self::Large => "large",
-            Self::XLarge => "xlarge",
-        };
-        // The XLarge model only has an ImageNet-22K variant
-        let variant = match self {
-            Self::XLarge => "fb_in22k_ft_in1k",
-            _ => "fb_in1k",
+            Self::Atto => "convnext_atto.d2_in1k",
+            Self::Femto => "convnext_femto.d1_in1k",
+            Self::Pico => "convnext_pico.d1_in1k",
+            Self::Nano => "convnext_nano.d1h_in1k",
+            Self::Tiny => "convnext_tiny.fb_in1k",
+            Self::Small => "convnext_small.fb_in1k",
+            Self::Base => "convnext_base.fb_in1k",
+            Self::Large => "convnext_large.fb_in1k",
+            Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
+            Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
+            Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
+            Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
+            Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
+            Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
+            Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
+            Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
+            Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
         };
 
-        format!("timm/convnext_{name}.{variant}")
+        format!("timm/{name}")
     }
 
     fn config(&self) -> convnext::Config {
         match self {
-            Self::Tiny => convnext::Config::tiny(),
+            Self::Atto | Self::AttoV2 => convnext::Config::atto(),
+            Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
+            Self::Pico | Self::PicoV2 => convnext::Config::pico(),
+            Self::Nano | Self::NanoV2 => convnext::Config::nano(),
+            Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
             Self::Small => convnext::Config::small(),
-            Self::Base => convnext::Config::base(),
-            Self::Large => convnext::Config::large(),
+            Self::Base | Self::BaseV2 => convnext::Config::base(),
+            Self::Large | Self::LargeV2 => convnext::Config::large(),
             Self::XLarge => convnext::Config::xlarge(),
+            Self::Huge => convnext::Config::huge(),
         }
     }
 }
candle-examples/examples/rwkv/README.md (new file, 17 lines)
@@ -0,0 +1,17 @@
## candle-rwkv

The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available; candle implements the v5 version, which can be used with Eagle 7B
([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).

```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
The smallest prime is ϕ(2) = 2.
The smallest composite is ϕ(3) = 3.
The smallest perfect number is ϕ(5) = 5.
The smallest perfect square is ϕ(4) = 4.
The smallest perfect cube is ϕ(6) = 6.
```
candle-examples/examples/rwkv/main.rs (new file, 265 lines)
@@ -0,0 +1,265 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use clap::{Parser, ValueEnum};

use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};

struct TextGeneration {
    model: Model,
    config: Config,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        config: Config,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            config,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        let mut tokens = self.tokenizer.encode(prompt)?;
        let mut generated_tokens = 0usize;
        let mut state = State::new(1, &self.config, &self.device)?;
        let mut next_logits = None;
        for &t in tokens.iter() {
            let input = Tensor::new(&[[t]], &self.device)?;
            let logits = self.model.forward(&input, &mut state)?;
            next_logits = Some(logits);
            print!("{}", self.tokenizer.decode(&[t])?)
        }
        std::io::stdout().flush()?;

        let start_gen = std::time::Instant::now();
        for _ in 0..sample_len {
            let logits = match next_logits.as_ref() {
                Some(logits) => logits,
                None => anyhow::bail!("cannot work on an empty prompt"),
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            print!("{}", self.tokenizer.decode(&[next_token])?);
            std::io::stdout().flush()?;

            let input = Tensor::new(&[[next_token]], &self.device)?;
            next_logits = Some(self.model.forward(&input, &mut state)?)
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
    Eagle7b,
    World1b5,
    World3b,
}

impl std::fmt::Display for Which {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl Which {
    fn model_id(&self) -> &'static str {
        match self {
            Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
            Self::World1b5 => "RWKV/rwkv-5-world-1b5",
            Self::World3b => "RWKV/rwkv-5-world-3b",
        }
    }

    fn revision(&self) -> &'static str {
        match self {
            Self::Eagle7b => "refs/pr/1",
            Self::World1b5 | Self::World3b => "refs/pr/2",
        }
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 5000)]
    sample_len: usize,

    #[arg(long, default_value = "world1b5")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id
            .unwrap_or_else(|| args.which.model_id().to_string()),
        RepoType::Model,
        args.revision
            .unwrap_or_else(|| args.which.revision().to_string()),
    ));
    let tokenizer = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("lmz/candle-rwkv".to_string())
            .get("rwkv_vocab_v20230424.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            vec![repo.get("model.safetensors")?]
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::new(tokenizer)?;

    let start = std::time::Instant::now();
    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let model = Model::new(&config, vb)?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        config,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -55,6 +55,11 @@ __device__ __forceinline__ T relu_fwd(T x) {
     return maxg(x, zero);
 }
 
+template<typename T>
+__device__ __forceinline__ T silu_fwd(T x) {
+    return x / (static_cast<T>(1) + expg(-x));
+}
+
 #define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
@@ -103,6 +108,7 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
 UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
 UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
 UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
+UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
 UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
 #endif
 
@@ -127,6 +133,7 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
 UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
 UNARY_OP(__half, urelu_f16, relu_fwd(x))
 UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
+UNARY_OP(__half, usilu_f16, silu_fwd(x))
 UNARY_OP1(__half, upowf_f16, powg(x, param))
 #endif
 
@@ -173,5 +180,7 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
 UNARY_OP(double, urelu_f64, relu_fwd(x))
 UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
 UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
+UNARY_OP(float, usilu_f32, silu_fwd(x))
+UNARY_OP(double, usilu_f64, silu_fwd(x))
 UNARY_OP1(float, upowf_f32, powg(x, param))
 UNARY_OP1(double, upowf_f64, powg(x, param))
@@ -1,2 +0,0 @@ (deleted file)
xcrun metal -c src/gemm/kernels/steel_gemm.metal -I src/
xcrun metallib steel_gemm.air -o src/gemm/steel_gemm.metallib
@@ -1,317 +0,0 @@ (deleted file)
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_stdlib>

using namespace metal;

#if defined(__HAVE_BFLOAT__)

typedef bfloat bfloat16_t;

#else

/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////

constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
  // Check for nan
  if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
      _fp_encoding_traits<float>::inf_mask) {
    return uint16_t(as_type<uint32_t>(0x7FC0));
  }
  // Take bits
  uint32_t float_bits = as_type<uint32_t>(x);

  // Round to nearest even
  float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);

  // Take upper 16 bits
  return float_bits >> 16;
}

constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
  // Upper 16 bits are the data and lower 16 bits are 0s
  return as_type<float>((uint32_t)x << 16);
}

struct _MLX_BFloat16;

template <typename T>
static constexpr constant bool can_convert_to_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;

template <typename T>
static constexpr constant bool can_convert_from_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;

/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////

struct _MLX_BFloat16 {
  /////////////////////////////////////////////////////////////////////////////
  // Constructors
  uint16_t bits_;
  _MLX_BFloat16() thread = default;
  _MLX_BFloat16() threadgroup = default;
  _MLX_BFloat16() device = default;
  _MLX_BFloat16() constant = default;

  struct bits_to_bfloat_struct {};
  static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
    return bits_to_bfloat_struct();
  }
  constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
      : bits_(bits) {}

  /////////////////////////////////////////////////////////////////////////////
  // Conversions to bfloat

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) thread
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) device
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) constant
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  /////////////////////////////////////////////////////////////////////////////
  // Conversions from bfloat

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const thread {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const threadgroup {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const device {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const constant {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }
};

/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////
// Unary ops
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
  return -static_cast<float>(x);
}

/////////////////////////////////////////////////////////////////////////////
// Binary operators
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
  constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
  }

#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
  constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
  } \
  constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
  }

/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
#define bfloat_binop(_op_, _operator_) \
  bfloat_binop_base( \
      _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(_op_, _operator_, float, float, float); \
  bfloat_binop_helper(_op_, _operator_, float, half, float); \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);

bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);

/////////////////////////////////////////////////////////////////////////////
// Comparison ops
#define bfloat_compop(__op__, __operator__) \
  bfloat_binop_base( \
      __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(__op__, __operator__, bool, float, float); \
  bfloat_binop_helper(__op__, __operator__, bool, half, float); \
  bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
  bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
  bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
  bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);

bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);

#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop

/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
  constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
      addr_space _MLX_BFloat16& lhs, itype rhs) { \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
    return lhs; \
  } \
  constexpr METAL_FUNC addr_space itype& __operator__( \
      addr_space itype& lhs, _MLX_BFloat16 rhs) { \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
    return lhs; \
  }

#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
  bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
  bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
  bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);

#define bfloat_inplace_op(itype) \
  bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
  bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
  bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
  bfloat_inplace_op_addr_space_helper(/, operator/=, itype);

bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);

#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op

#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
  constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
      addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
    return lhs; \
  }

#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
  bfloat_inplace_op_helper(__op__, __operator__, device); \
  bfloat_inplace_op_helper(__op__, __operator__, thread); \
  bfloat_inplace_op_helper(__op__, __operator__, threadgroup);

bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);

#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper

/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////

typedef struct _MLX_BFloat16 bfloat16_t;

/////////////////////////////////////////////////////////////////////////////
// Bfloat numeric limits
/////////////////////////////////////////////////////////////////////////////

#pragma METAL internals : enable

namespace metal {

template <>
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
  static constexpr constant int digits = 8;
  static constexpr constant int digits10 = 2;
  static constexpr constant int max_digits10 = 4;
  static constexpr constant int radix = 2;
  static constexpr constant int min_exponent = -125;
  static constexpr constant int min_exponent10 = -37;
  static constexpr constant int max_exponent = 128;
  static constexpr constant int max_exponent10 = 38;

  static constexpr bfloat16_t min() {
    return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t lowest() {
    return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t max() {
    return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t epsilon() {
    return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t round_error() {
    return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t infinity() {
    return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t quiet_NaN() {
    return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t signaling_NaN() {
    return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t denorm_min() {
    return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
  }
};

METAL_FUNC bool isnan(_MLX_BFloat16 x) {
  return x != x;
}

} // namespace metal

#pragma METAL internals : disable

#endif // defined(__HAVE_BFLOAT__)

#include "gemm/bf16_math.h"
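The deleted fallback's `float_to_bfloat_bits` uses the classic round-to-nearest-even bit trick: add `0x7FFF` plus the current LSB of the would-be result, then keep the top 16 bits. A Rust sketch of the same conversion, for reference:

```rust
// f32 -> bf16 bits with round-to-nearest-even, mirroring the deleted
// Metal helper. NaN inputs map to the canonical quiet NaN 0x7FC0.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if bits & 0x7FFF_FFFF > 0x7F80_0000 {
        return 0x7FC0; // NaN
    }
    // Round to nearest even, then take the upper 16 bits.
    let rounded = bits + (((bits >> 16) & 1) + 0x7FFF);
    (rounded >> 16) as u16
}

// bf16 bits -> f32: the upper 16 bits are the data, lower 16 bits are 0s.
fn bf16_bits_to_f32(x: u16) -> f32 {
    f32::from_bits((x as u32) << 16)
}

fn main() {
    for x in [1.0f32, 3.1415926, -0.0001, f32::INFINITY] {
        let b = f32_to_bf16_bits(x);
        println!("{x} -> {b:#06x} -> {}", bf16_bits_to_f32(b));
    }
    assert_eq!(f32_to_bf16_bits(1.0), 0x3F80);
}
```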
@@ -1,394 +0,0 @@ (deleted file)
// Copyright © 2023 Apple Inc.

#pragma once

#include "gemm/bf16.h"

///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////

/*

Following the Metal Shading Language Specification (Metal 3.1)

"bfloat is an extended floating point type that only allows implicit conversion
to a type of greater floating point rank. While bfloat can be implicitly
converted to float, it cannot be implicitly converted to half, and neither
float nor half can be implicitly converted to bfloat."

Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat and calling with an argument of type bfloat will result in that
argument getting implicitly converted to float which then returns an output
that is (likely) a float which cannot be implicitly converted into a bfloat

This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this will throw an error since abs returns float
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine

For the moment, I will be adding overloaded instantiations of the math
functions to accordingly automatically handle the casting

*/

#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
  \
  METAL_FUNC otype abs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acos(itype x) { \
    return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acosh(itype x) { \
    return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asin(itype x) { \
    return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asinh(itype x) { \
    return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atan(itype y_over_x) { \
    return static_cast<otype>( \
        __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
  } \
  METAL_FUNC otype atan2(itype y, itype x) { \
    return static_cast<otype>( \
        __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atanh(itype x) { \
    return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype ceil(itype x) { \
    return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cos(itype x) { \
    return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cosh(itype x) { \
    return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cospi(itype x) { \
    return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype divide(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype exp(itype x) { \
    return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp10(itype x) { \
    return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp2(itype x) { \
    return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fabs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fdim(itype x, itype y) { \
    ctype t = static_cast<ctype>(x - y); \
    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
  } \
  METAL_FUNC otype floor(itype x) { \
    return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fma( \
        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
  } \
  METAL_FUNC otype fmax(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmin(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmod(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fract(itype x) { \
    return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
  } \
  METAL_FUNC otype ldexp(itype x, int k) { \
    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
  } \
  METAL_FUNC otype log(itype x) { \
    return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log10(itype x) { \
    return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log2(itype x) { \
    return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype max(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype max3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype median3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype min(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype min3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
  } \
  METAL_FUNC otype pow(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype powr(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype rint(itype x) { \
    return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype round(itype x) { \
    return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype rsqrt(itype x) { \
    return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sin(itype x) { \
    return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinh(itype x) { \
    return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinpi(itype x) { \
    return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sqrt(itype x) { \
    return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tan(itype x) { \
    return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanh(itype x) { \
    return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanpi(itype x) { \
    return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype trunc(itype x) { \
    return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
  }

namespace metal {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_MAYBE_FAST_MATH__);

namespace fast {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_FAST_MATH__);

} // namespace fast

namespace precise {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_PRECISE_MATH__);

} // namespace precise

} // namespace metal

///////////////////////////////////////////////////////////////////////////////
// Metal simd for bfloat16
///////////////////////////////////////////////////////////////////////////////

#define instantiate_metal_simd_comm_funcs( \
    itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
  \
  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
    return ctype_to_otype( \
        __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
    return ctype_to_otype( \
        __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
  }

#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
  \
  METAL_FUNC otype simd_max(itype data) { \
    return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_min(itype data) { \
    return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_product(itype data) { \
    return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_sum(itype data) { \
    return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_xor(itype data) { \
    return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
  }

#if defined(__HAVE_BFLOAT__)

#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

#else

#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())

#endif

namespace metal {

instantiate_metal_simd_comm_funcs(
    bfloat16_t,
    bfloat16_t,
    uint16_t,
    bfloat16_to_uint16,
    uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);

} // namespace metal
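The deleted header worked around the missing bf16 stdlib math by promoting to f32, calling the f32 routine, and demoting the result. The same pattern in Rust, using the `half` crate (an assumption: `half = "2"` in Cargo.toml, the crate candle itself depends on):

```rust
use half::bf16;

// Promote-compute-demote: evaluate in f32, round back to bf16.
fn bf16_tanh(x: bf16) -> bf16 {
    bf16::from_f32(x.to_f32().tanh())
}

fn main() {
    let x = bf16::from_f32(0.5);
    println!("tanh({x}) = {}", bf16_tanh(x));
}
```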
@ -1,131 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct complex64_t;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_to_complex64 =
|
||||
    !is_same_v<T, complex64_t> && is_convertible_v<T, float>;

template <typename T>
static constexpr constant bool can_convert_from_complex64 =
    !is_same_v<T, complex64_t> &&
    (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);

struct complex64_t {
  float real;
  float imag;

  // Constructors
  constexpr complex64_t(float real, float imag) : real(real), imag(imag) {}

  // Conversions to complex64_t
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) thread : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) device : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) constant : real(x), imag(0) {}

  // Conversions from complex64_t
  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const thread {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const threadgroup {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const device {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const constant {
    return static_cast<T>(real);
  }
};

constexpr complex64_t operator-(complex64_t x) {
  return {-x.real, -x.imag};
}

constexpr bool operator>=(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}

constexpr bool operator>(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}

constexpr bool operator<=(complex64_t a, complex64_t b) {
  return operator>=(b, a);
}

constexpr bool operator<(complex64_t a, complex64_t b) {
  return operator>(b, a);
}

constexpr bool operator==(complex64_t a, complex64_t b) {
  return a.real == b.real && a.imag == b.imag;
}

constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return {a.real + b.real, a.imag + b.imag};
}

constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return {a.real - b.real, a.imag - b.imag};
}

constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}

constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  auto denom = b.real * b.real + b.imag * b.imag;
  auto x = a.real * b.real + a.imag * b.imag;
  auto y = a.imag * b.real - a.real * b.imag;
  return {x / denom, y / denom};
}

constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
  if (real != 0 && (real < 0 != b.real < 0)) {
    real += b.real;
  }
  if (imag != 0 && (imag < 0 != b.imag < 0)) {
    imag += b.imag;
  }
  return {real, imag};
}
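A note on the operators above: ordering on complex64_t is lexicographic (real part first, imaginary part as tie-breaker), which gives min/max style reductions something well-defined to compare, and operator% applies a floored modulo to the real and imaginary components independently. A minimal host-side C++ sketch of the comparison rule, purely illustrative (the c64 stand-in and greater helper are ours, not part of the kernels):

#include <cassert>

struct c64 { float real, imag; };  // host stand-in for complex64_t

// Same lexicographic rule as the Metal operator> above.
bool greater(c64 a, c64 b) {
  return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}

int main() {
  assert(greater({2.f, 0.f}, {1.f, 5.f}));  // real part dominates
  assert(greater({1.f, 2.f}, {1.f, 1.f}));  // imaginary part breaks ties
}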
@ -1,292 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "gemm/loader.h"
#include "gemm/mma.h"
#include "gemm/transforms.h"
#include "utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};

template <
    typename T,
    typename U,
    int BM, int BN, int BK,
    int WM, int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  STEEL_CONST short tgp_size = WM * WN * 32;

  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;
  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;
  using mma_t = BlockMMA<
      T,
      U,
      BM, BN, BK,
      WM, WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Inner k-loop over threadgroup tiles */
  template <bool M_aligned, bool N_aligned, bool K_aligned_>
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
    // Appease the compiler
    (void)l;

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* C [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Find block in A, B, C
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;

    A += transpose_a ? c_row : c_row * params->lda;
    B += transpose_b ? c_col * params->ldb : c_col;
    C += c_row * params->ldc + c_col;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    int gemm_k_iterations = params->gemm_k_iterations_aligned;

    ///////////////////////////////////////////////////////////////////////////
    // MNK aligned loop
    if (MN_aligned) {
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      threadgroup_barrier(mem_flags::mem_none);

      // Loop tail
      if (!K_aligned) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(C, params->ldc);
      return;
    }
    ///////////////////////////////////////////////////////////////////////////
    // MN unaligned loop
    else { // Loop over K - unaligned case
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

      if (tgp_bm == BM && tgp_bn == BN) {
        gemm_loop<true, true, K_aligned>(
            As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
            tgp_bm, tgp_bn, leftover_bk);

        mma_op.store_result(C, params->ldc);
        return;
      } else if (tgp_bn == BN) {
        gemm_loop<false, true, K_aligned>(
            As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
            tgp_bm, tgp_bn, leftover_bk);

        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
        return;
      } else if (tgp_bm == BM) {
        gemm_loop<true, false, K_aligned>(
            As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
            tgp_bm, tgp_bn, leftover_bk);

        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
        return;
      } else {
        gemm_loop<false, false, K_aligned>(
            As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
            tgp_bm, tgp_bn, leftover_bk);

        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};

} // namespace steel
} // namespace mlx
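The tgp_padding_a/b terms above pad each threadgroup tile row by 16 bytes worth of elements so vectorized loads can run past the block edge safely, and tgp_mem_size_a + tgp_mem_size_b is the full shared-memory footprint per threadgroup. A host-side sketch of that budget for one plausible half-precision 64x64x16 configuration (the numbers just instantiate the formulas above; nothing here is measured):

#include <cstdio>

int main() {
  // Instantiates tgp_mem_size_a/b above for half precision (sizeof == 2),
  // BM=64, BN=64, BK=16, no transposes: tgp_padding = 16 / 2 = 8 elements.
  const int BM = 64, BN = 64, BK = 16, pad = 16 / 2;
  int tgp_mem_size_a = BM * (BK + pad);  // 64 * 24 = 1536 elements
  int tgp_mem_size_b = BK * (BN + pad);  // 16 * 72 = 1152 elements
  int bytes = (tgp_mem_size_a + tgp_mem_size_b) * 2;
  printf("threadgroup memory: %d bytes\n", bytes);  // 5376
}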
@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "params.h"
@ -1,89 +0,0 @@
// Copyright © 2024 Apple Inc.

#include "gemm/bf16.h"
#include "gemm/gemm.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          int BM, int BN, int BK,
          int WM, int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device T *C [[buffer(2)]],
    const constant GEMMParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using gemm_kernel = GEMMKernel<
      T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Adjust for batch
  A += params->batch_stride_a * tid.z;
  B += params->batch_stride_b * tid.z;
  C += params->batch_stride_c * tid.z;

  gemm_kernel::run(
      A, B, C,
      params,
      As, Bs,
      simd_lane_id, simd_group_id, tid, lid);
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      device itype *C [[buffer(2)]], \
      const constant GEMMParams* params [[buffer(3)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

instantiate_gemm_shapes_helper(float32, float, float32, float);
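Each instantiate_gemm expansion stamps a host-visible name that encodes every template choice, which is how the host picks a specialization at dispatch time. A sketch of assembling such a name on the host, with illustrative parameter values (the snprintf helper is ours; only the name format comes from the macro above):

#include <cstdio>

int main() {
  // Rebuilds the host_name string produced by instantiate_gemm for one
  // illustrative combination (nt transpose pair, float16 in/out, 64x64x16
  // tiles, 2x2 warps, both MN and K aligned).
  const char* tname = "nt";
  const char* iname = "float16";
  const char* oname = "float16";
  int bm = 64, bn = 64, bk = 16, wm = 2, wn = 2;
  char name[128];
  snprintf(name, sizeof(name),
           "steel_gemm_%s_%s_%s_bm%d_bn%d_bk%d_wm%d_wn%d_MN_%s_K_%s",
           tname, iname, oname, bm, bn, bk, wm, wn, "taligned", "taligned");
  printf("%s\n", name);
}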
@ -1,254 +0,0 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          int BM, int BN, int BK,
          int WM, int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned,
          typename AccumType = float,
          typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    const device T *C [[buffer(2)]],
    device T *D [[buffer(3)]],
    const constant GEMMAddMMParams* params [[buffer(4)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Pacifying compiler
  (void)lid;

  using gemm_kernel =
      GEMMKernel<T, T, BM, BN, BK, WM, WN,
                 transpose_a, transpose_b,
                 MN_aligned, K_aligned,
                 AccumType, Epilogue>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Adjust for batch
  A += params->batch_stride_a * tid.z;
  B += params->batch_stride_b * tid.z;
  C += params->batch_stride_c * tid.z;
  D += params->batch_stride_d * tid.z;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C and D
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;

  A += transpose_a ? c_row : c_row * params->lda;
  B += transpose_b ? c_col * params->ldb : c_col;
  C += c_row * params->ldc + c_col * params->fdc;
  D += c_row * params->ldd + c_col;

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  const Epilogue epilogue_op(params->alpha, params->beta);

  /////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail
    if (!K_aligned) {
      int lbk = params->K - params->gemm_k_iterations_aligned * BK;
      short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
      short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
    return;
  }
  /////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

    if (tgp_bm == BM && tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<true, true, K_aligned>{});

      mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
      return;
    } else if (tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<false, true, K_aligned>{});

      return mma_op.store_result_safe(
          D, params->ldd,
          C, params->ldc, params->fdc,
          short2(tgp_bn, tgp_bm),
          epilogue_op);
    } else if (tgp_bm == BM) {
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<true, false, K_aligned>{});

      return mma_op.store_result_safe(
          D, params->ldd,
          C, params->ldc, params->fdc,
          short2(tgp_bn, tgp_bm),
          epilogue_op);
    } else {
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<false, false, K_aligned>{});

      return mma_op.store_result_safe(
          D, params->ldd,
          C, params->ldc, params->fdc,
          short2(tgp_bn, tgp_bm),
          epilogue_op);
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
  template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
  [[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      const device itype *C [[buffer(2)]], \
      device itype *D [[buffer(3)]], \
      const constant GEMMAddMMParams* params [[buffer(4)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

instantiate_gemm_shapes_helper(float32, float, float32, float);
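addmm fuses the bias/residual term into the GEMM epilogue: with the axpby transform each accumulated value x is written out as alpha * x + beta * c against the matching element of C, so no separate add pass over D is needed. A scalar host-side model of that contract (the axpby_apply name is ours):

#include <cassert>

// Scalar model of TransformAxpby::apply as used by addmm: each accumulated
// GEMM value x is combined with the matching C element on store.
float axpby_apply(float x, float c, float alpha, float beta) {
  return x * alpha + beta * c;
}

int main() {
  // x = one (A@B) element, c = matching bias element.
  assert(axpby_apply(/*x=*/3.f, /*c=*/2.f, /*alpha=*/1.f, /*beta=*/0.5f) == 4.f);
}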
@ -1,280 +0,0 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          typename U,
          int BM, int BN, int BK,
          int WM, int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device U *C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  const int tid_x = tid.x;
  const int tid_y = tid.y;
  const int tid_z = tid.z;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int k_start = params->split_k_partition_size * tid_z;

  A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
  B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
  C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  short tgp_bm = min(BM, params->M - c_row);
  short tgp_bn = min(BN, params->N - c_col);
  short leftover_bk = params->K % BK;

  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
    gemm_kernel::gemm_loop(
        As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
        tgp_bm, tgp_bn, leftover_bk,
        LoopAlignment<true, true, true>{});
  } else if (tgp_bn == BN) {
    gemm_kernel::gemm_loop(
        As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
        tgp_bm, tgp_bn, leftover_bk,
        LoopAlignment<false, true, true>{});
  } else if (tgp_bm == BM) {
    gemm_kernel::gemm_loop(
        As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
        tgp_bm, tgp_bn, leftover_bk,
        LoopAlignment<true, false, true>{});
  } else {
    gemm_kernel::gemm_loop(
        As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
        tgp_bm, tgp_bn, leftover_bk,
        LoopAlignment<false, false, true>{});
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if ((tid_z + 1) == (params->split_k_partitions)) {
    int gemm_k_iter_remaining =
        (params->K - (k_start + params->split_k_partition_size)) / BK;
    if (!K_aligned || gemm_k_iter_remaining > 0)
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iter_remaining, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<false, false, K_aligned>{});
  }

  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
    mma_op.store_result(C, params->ldc);
  } else {
    mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
  }
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
  [[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      device otype *C [[buffer(2)]], \
      const constant GEMMSpiltKParams* params [[buffer(3)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);

instantiate_gemm_shapes_helper(float32, float, float32, float);

///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////

template <typename AccT,
          typename OutT,
          typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
    const device AccT *C_split [[buffer(0)]],
    device OutT *D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]) {
  // Adjust D and C_split
  D += gid.x + gid.y * ldd;
  C_split += gid.x + gid.y * ldd;

  int offset = 0;
  AccT out = 0;

  for (int i = 0; i < k_partitions; i++) {
    out += C_split[offset];
    offset += partition_stride;
  }

  // Write output
  D[0] = Epilogue::apply(out);
}

template <typename AccT,
          typename OutT,
          typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
    const device AccT *C_split [[buffer(0)]],
    device OutT *D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device OutT *C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]) {
  // Adjust C, D and C_split
  C += gid.x * fdc + gid.y * ldc;
  D += gid.x + gid.y * ldd;
  C_split += gid.x + gid.y * ldd;

  int offset = 0;
  AccT out = 0;

  for (int i = 0; i < k_partitions; i++) {
    out += C_split[offset];
    offset += partition_stride;
  }

  // Write output
  Epilogue op(alpha, beta);
  D[0] = op.apply(out, *C);
}

#define instantiate_accum(oname, otype, aname, atype) \
  template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
  [[kernel]] void gemm_splitk_accum<atype, otype>( \
      const device atype *C_split [[buffer(0)]], \
      device otype *D [[buffer(1)]], \
      const constant int& k_partitions [[buffer(2)]], \
      const constant int& partition_stride [[buffer(3)]], \
      const constant int& ldd [[buffer(4)]], \
      uint2 gid [[thread_position_in_grid]]); \
  template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
  [[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
      const device atype *C_split [[buffer(0)]], \
      device otype *D [[buffer(1)]], \
      const constant int& k_partitions [[buffer(2)]], \
      const constant int& partition_stride [[buffer(3)]], \
      const constant int& ldd [[buffer(4)]], \
      const device otype *C [[buffer(5)]], \
      const constant int& ldc [[buffer(6)]], \
      const constant int& fdc [[buffer(7)]], \
      const constant float& alpha [[buffer(8)]], \
      const constant float& beta [[buffer(9)]], \
      uint2 gid [[thread_position_in_grid]]);

instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);
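Split-k parallelizes along the reduction axis: each tid.z slice writes a partial product into its own partition of the staging buffer, and gemm_splitk_accum then sums k_partitions values per output element, partition_stride apart. A host-side sketch of that second pass (illustrative names):

#include <cassert>
#include <vector>

// Host-side model of gemm_splitk_accum: sum one output element across
// k_partitions slices laid out partition_stride apart in C_split.
float accum(const std::vector<float>& c_split, int idx,
            int k_partitions, int partition_stride) {
  float out = 0.f;
  for (int i = 0; i < k_partitions; i++)
    out += c_split[idx + i * partition_stride];
  return out;
}

int main() {
  // Two partitions of a 1-element output: 1.5 + 2.5 == 4.0
  std::vector<float> c_split = {1.5f, 2.5f};
  assert(accum(c_split, /*idx=*/0, /*k_partitions=*/2, /*partition_stride=*/1) == 4.0f);
}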
@ -1,125 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "utils2.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

} // namespace steel
} // namespace mlx
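In BlockLoader, thread_idx is decomposed into a row bi and a starting column bj so that the tgp_size threads tile the BROWS x BCOLS block exactly, each thread moving vec_size contiguous elements per row step. A quick host-side sanity check of that decomposition for one plausible configuration (BROWS=64, BCOLS=16, tgp_size=128 matches a BM=64, BK=16, 2x2-warp kernel above):

#include <cassert>

int main() {
  const int BROWS = 64, BCOLS = 16, tgp_size = 128;
  const int n_reads = (BCOLS * BROWS) / tgp_size;  // 8 elements per thread
  const int TCOLS = BCOLS / n_reads;               // 2 thread columns
  const int TROWS = tgp_size / TCOLS;              // 64 thread rows
  assert(n_reads == 8 && TCOLS == 2 && TROWS == 64);
  // Every element of the 64x16 tile is covered exactly once:
  assert(tgp_size * n_reads == BROWS * BCOLS);
  // e.g. thread 5 starts at row bi = 2, column bj = 8:
  int thread_idx = 5;
  assert(thread_idx / TCOLS == 2 && n_reads * (thread_idx % TCOLS) == 8);
}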
@ -1,264 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "gemm/transforms.h"
#include "utils.h"

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    typename U,
    int BM, int BN, int BK,
    int WM, int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = 8 * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / TM_stride;
  // Warp tile size along N
  STEEL_CONST short TN = BN / TN_stride;

  // Strides of A, B along reduction axis
  STEEL_CONST short simd_stride_a = {
      transpose_a ? TM_stride : TM_stride * lda_tgp};
  STEEL_CONST short simd_stride_b = {
      transpose_b ? TN_stride * ldb_tgp : TN_stride};

  // Jump between elements
  STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
  STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};

  STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
  STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  // Offsets within threadgroup
  const short tm;
  const short tn;

  short sm;
  short sn;

  short As_offset;
  short Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    // Determine thread position in simdgroup matrix
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

    // Determine thread and simdgroup offset
    As_offset =
        transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
    Bs_offset =
        transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of 8
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += 8) {
      simdgroup_barrier(mem_flags::mem_none);

      // Load elements from threadgroup A as simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] =
            static_cast<AccumType>(As[i * simd_stride_a + 0]);
        Asimd[i].thread_elements()[1] =
            static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Load elements from threadgroup B as simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] =
            static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
        Bsimd[j].thread_elements()[1] =
            static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Multiply and accumulate into result simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < TM; i++) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < TN; j++) {
          short j_serp = (i % 2) ? (TN - 1 - j) : j;

          simdgroup_multiply_accumulate(
              results[i * TN + j_serp],
              Asimd[i],
              Bsimd[j_serp],
              results[i * TN + j_serp]);
        }
      }

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* C, const int ldc) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + tn + sn;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = results[i * TN + j].thread_elements();
        int offset = (i * TM_stride) * ldc + (j * TN_stride);

        // Apply epilogue
        U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};

        // Write out C
        C[offset] = outs[0];
        C[offset + 1] = outs[1];
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn);
    dst_tile_dims -= short2(tn + sn, sm + tm);

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = results[i * TN + j].thread_elements();
          int offset = (i * TM_stride) * ldc + (j * TN_stride);

          // Apply epilogue and output C
          if (j * TN_stride < dst_tile_dims.x) {
            C[offset] = Epilogue::apply(accum[0]);
          }

          if (j * TN_stride + 1 < dst_tile_dims.x) {
            C[offset + 1] = Epilogue::apply(accum[1]);
          }
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    D += (sm + tm) * ldd + tn + sn;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = results[i * TN + j].thread_elements();
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        U outs[2] = {
            epilogue_op.apply(accum[0], C[offset_c]),
            epilogue_op.apply(accum[1], C[offset_c + fdc])};

        // Write out D
        D[offset_d] = outs[0];
        D[offset_d + 1] = outs[1];
      }
    }
  }

  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    D += (sm + tm) * ldd + tn + sn;
    dst_tile_dims -= short2(tn + sn, sm + tm);

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = results[i * TN + j].thread_elements();
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue and output D
          if (j * TN_stride < dst_tile_dims.x) {
            D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
          }

          if (j * TN_stride + 1 < dst_tile_dims.x) {
            D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
          }
        }
      }
    }
  }
};

} // namespace steel
} // namespace mlx
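The j_serp index in mma() walks the TN accumulators in serpentine order, reversing direction on odd rows of the warp tile so adjacent simdgroup_multiply_accumulate calls reuse the B operand just touched. A tiny host-side sketch of the traversal it produces:

#include <cstdio>

int main() {
  const int TM = 4, TN = 4;
  // Same index rule as j_serp in BlockMMA::mma above.
  for (int i = 0; i < TM; i++) {
    for (int j = 0; j < TN; j++) {
      int j_serp = (i % 2) ? (TN - 1 - j) : j;
      printf("%d ", i * TN + j_serp);  // 0 1 2 3  7 6 5 4  8 9 10 11  15 14 13 12
    }
  }
  printf("\n");
}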
@ -1,79 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

struct GEMMParams {
  const int M;
  const int N;
  const int K;

  const int lda;
  const int ldb;
  const int ldc;

  const int tiles_n;
  const int tiles_m;

  const int batch_stride_a;
  const int batch_stride_b;
  const int batch_stride_c;

  const int swizzle_log;
  const int gemm_k_iterations_aligned;
};

struct GEMMSpiltKParams {
  const int M;
  const int N;
  const int K;

  const int lda;
  const int ldb;
  const int ldc;

  const int tiles_n;
  const int tiles_m;

  const int split_k_partitions;
  const int split_k_partition_stride;
  const int split_k_partition_size;

  const int gemm_k_iterations_aligned;
};

struct GEMMAddMMParams {
  const int M;
  const int N;
  const int K;

  const int lda;
  const int ldb;
  const int ldc;
  const int ldd;

  const int tiles_n;
  const int tiles_m;

  const int batch_stride_a;
  const int batch_stride_b;
  const int batch_stride_c;
  const int batch_stride_d;

  const int swizzle_log;
  const int gemm_k_iterations_aligned;

  const float alpha;
  const float beta;

  const int fdc;
};

} // namespace steel
} // namespace mlx
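For orientation, lda/ldb/ldc here are element strides (leading dimensions) of the operands and fdc is the per-column stride of C in addmm. A hypothetical host-side sketch of filling the equivalent of GEMMParams for a single-batch, row-major, non-transposed C = A @ B; the field values are our assumptions about the common row-major case, not taken from the candle host code:

#include <cstdio>

// Hypothetical host-side mirror of GEMMParams. Field meanings follow the
// struct above; the values filled in reflect row-major A (M x K),
// B (K x N), C (M x N), single batch, no swizzling.
struct GEMMParamsHost {
  int M, N, K;
  int lda, ldb, ldc;
  int tiles_n, tiles_m;
  int batch_stride_a, batch_stride_b, batch_stride_c;
  int swizzle_log, gemm_k_iterations_aligned;
};

GEMMParamsHost make_params(int M, int N, int K, int BM, int BN, int BK) {
  return GEMMParamsHost{
      M, N, K,
      /*lda=*/K, /*ldb=*/N, /*ldc=*/N,  // row-major leading dimensions
      /*tiles_n=*/(N + BN - 1) / BN,
      /*tiles_m=*/(M + BM - 1) / BM,
      /*batch_stride_a=*/M * K, /*batch_stride_b=*/K * N, /*batch_stride_c=*/M * N,
      /*swizzle_log=*/0,
      /*gemm_k_iterations_aligned=*/K / BK};
}

int main() {
  GEMMParamsHost p = make_params(128, 128, 64, 64, 64, 16);
  printf("tiles: %d x %d, k iters: %d\n",
         p.tiles_m, p.tiles_n, p.gemm_k_iterations_aligned);  // 2 x 2, 4
}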
Binary file not shown.
@ -1,63 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "utils.h"

///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT) {
    return static_cast<OutT>(x);
  }
};

template <typename OutT, typename InT>
struct TransformAdd {
  TransformAdd(const float, const float) {}

  static METAL_FUNC OutT apply(InT x, OutT c) {
    return static_cast<OutT>(x) + c;
  }
};

template <typename OutT, typename InT>
struct TransformAxpby {
  const float alpha;
  const float beta;

  TransformAxpby(const float alpha_, const float beta_)
      : alpha(alpha_), beta(beta_) {}

  METAL_FUNC OutT apply(InT x, OutT c) const {
    return static_cast<OutT>(x * alpha + (beta * c));
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

struct BlockSwizzle {
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    const int tid_x = (tid.x) >> swizzle_log;
    const int tid_y =
        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
    return int2(tid_x, tid_y);
  }
};

} // namespace steel
} // namespace mlx
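BlockSwizzle remaps the launch grid so that runs of 2^swizzle_log consecutive tid.x values land in neighboring tile rows, which tends to improve cache locality on wide grids. A host-side sketch of the mapping for swizzle_log = 1 (same arithmetic as swizzle above):

#include <cstdio>

int main() {
  const int swizzle_log = 1;
  for (int ty = 0; ty < 2; ty++) {
    for (int tx = 0; tx < 4; tx++) {
      int tid_x = tx >> swizzle_log;
      int tid_y = (ty << swizzle_log) + (tx & ((1 << swizzle_log) - 1));
      printf("(%d,%d)->(%d,%d) ", tx, ty, tid_x, tid_y);
    }
    printf("\n");  // row 0: (0,0)(0,1)(1,0)(1,1); row 1 shifts tid_y by 2
  }
}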
@ -1,276 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_math>
#include "gemm/bf16.h"
#include "gemm/complex.h"

///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

template <typename U>
struct Limits {
  static const constant U max = metal::numeric_limits<U>::max();
  static const constant U min = metal::numeric_limits<U>::min();
  static const constant U finite_max = metal::numeric_limits<U>::max();
  static const constant U finite_min = metal::numeric_limits<U>::min();
};

#define instantiate_default_limit(type)                                       \
  template <>                                                                 \
  struct Limits<type> {                                                       \
    static constexpr constant type max = metal::numeric_limits<type>::max();  \
    static constexpr constant type min = metal::numeric_limits<type>::min();  \
    static constexpr constant type finite_max =                               \
        metal::numeric_limits<type>::max();                                   \
    static constexpr constant type finite_min =                               \
        metal::numeric_limits<type>::min();                                   \
  };

instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);

#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
  };

instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);

template <>
struct Limits<bool> {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
};
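Note the deliberate asymmetry for float types: max/min are the infinities while finite_max/finite_min are the largest finite magnitudes, so a reduction can pick the correct identity element. A host-side illustration using the standard library equivalents:

#include <cassert>
#include <limits>

int main() {
  // Mirrors instantiate_float_limit above: Limits<float>::max is +inf,
  // Limits<float>::finite_max is the largest finite float.
  float inf_max = std::numeric_limits<float>::infinity();
  float finite_max = std::numeric_limits<float>::max();
  assert(inf_max > finite_max);
  assert(-inf_max < -finite_max);  // min vs finite_min, likewise
}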
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////

inline size_t elem_to_loc(
    uint elem,
    device const int* shape,
    device const size_t* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

inline size_t elem_to_loc(
    uint elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

template <int NDIM>
inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM]) {
  uint2 loc = {
      static_cast<uint>(
          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
  for (int d = NDIM - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

template <int NDIM>
inline size_t elem_to_loc_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t strides[NDIM]) {
  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
  for (int d = NDIM - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
  return elem * stride;
}

inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
  return elem.x * strides[1] + elem.y * strides[0];
}

inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}

// Non templated version to handle arbitrary dims
inline size_t elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
  for (int d = ndim - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    int ndim) {
  uint2 loc = {
      static_cast<uint>(
          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
      static_cast<uint>(
          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

template <int NDIM>
inline uint elem_to_loc_nd(
    uint elem,
    device const int* shape,
    device const size_t* strides);

template <>
inline uint elem_to_loc_nd<1>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  return (elem % shape[0]) * strides[0];
}

template <>
inline uint elem_to_loc_nd<2>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  uint loc = (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
  return loc;
}

template <>
inline uint elem_to_loc_nd<3>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  uint loc = (elem % shape[2]) * strides[2];
  elem /= shape[2];
  loc += (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
  return loc;
}

template <>
inline uint elem_to_loc_nd<4>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  uint loc = (elem % shape[3]) * strides[3];
  elem /= shape[3];
  loc += (elem % shape[2]) * strides[2];
  elem /= shape[2];
  loc += (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
  return loc;
}
///////////////////////////////////////////////////////////////////////////////
|
||||
// Calculation utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** Compute ceil((float)N/(float)M) */
|
||||
inline size_t ceildiv(size_t N, size_t M) {
|
||||
return (N + M - 1) / M;
|
||||
}
|
||||
|
||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||
inline float log1p(float x) {
|
||||
float xp1 = 1.0f + x;
|
||||
if (xp1 == Limits<float>::max) {
|
||||
return Limits<float>::max;
|
||||
}
|
||||
if (xp1 == 1.0f) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
}
|
||||
|
||||
inline bfloat16_t log1p(bfloat16_t x) {
|
||||
float xp1 = 1.0f + static_cast<float>(x);
|
||||
if (xp1 == Limits<float>::max) {
|
||||
return Limits<bfloat16_t>::max;
|
||||
}
|
||||
if (xp1 == 1.0f) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
}
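
The log1p above uses the classic Goldberg correction: log(xp1)/(xp1 - 1) varies slowly near 1, so the rounding error incurred when forming xp1 = 1 + x largely cancels between numerator and denominator. A standalone Rust sketch of the same trick (illustration only; names here are not from the kernel source):

fn log1p_goldberg(x: f32) -> f32 {
    let xp1 = 1.0_f32 + x;
    if xp1 == 1.0 {
        // 1 + x rounded to exactly 1: log(1 + x) ~ x to first order.
        x
    } else {
        // The slowly-varying ratio cancels most of the rounding error in xp1.
        x * (xp1.ln() / (xp1 - 1.0))
    }
}

fn main() {
    let x = 1e-7_f32;
    // Naive evaluation loses most significant digits for tiny x.
    println!("naive     = {}", (1.0_f32 + x).ln());
    println!("corrected = {}", log1p_goldberg(x));
    println!("std       = {}", x.ln_1p());
}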

///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////

inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  return as_type<uint64_t>(
      metal::simd_shuffle_down(as_type<uint2>(data), delta));
}

inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  return as_type<int64_t>(
      metal::simd_shuffle_down(as_type<uint2>(data), delta));
}

inline bool simd_shuffle_down(bool data, uint16_t delta) {
  return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>
#include "gemm/host.h"

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
@ -1,7 +1,6 @@
use metal::{
    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
    NSUInteger,
    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@ -17,7 +16,6 @@ const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const GEMM: &[u8] = include_bytes!("gemm/steel_gemm.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");

/// Most kernels apply similarly across the tensors
@ -124,7 +122,6 @@ pub enum Source {
    Cast,
    Reduce,
    Mfa,
    Gemm,
    Conv,
    Random,
    Quantized,
@ -186,7 +183,7 @@ macro_rules! ops{
pub mod unary {
    ops!(
        cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
        tanh, recip
        tanh, recip, silu
    );
}
pub mod binary {
@ -251,7 +248,6 @@ impl Kernels {
            Source::Random => RANDOM,
            Source::Quantized => QUANTIZED,
            Source::Mfa => panic!("Invalid lib"),
            Source::Gemm => panic!("Invalid lib"),
        }
    }

@ -275,14 +271,6 @@ impl Kernels {
                    ))
                })?
            }
            Source::Gemm => {
                let source_data = GEMM;
                device.new_library_with_data(source_data).map_err(|e| {
                    MetalKernelError::LoadLibraryError(format!(
                        "Candle metal requires macosx > 13.0 or higher, cannot load GEMM: {e}"
                    ))
                })?
            }
            source => {
                let source_content = self.get_library_source(source);
                device
@ -1242,34 +1230,6 @@ impl ConstantValues {
    }
}

fn string_to_static_str(s: String) -> &'static str {
    Box::leak(s.into_boxed_str())
}

use core::ffi::c_int;

#[repr(C)]
#[derive(Debug)]
struct GEMMParams {
    m: c_int,
    n: c_int,
    k: c_int,

    lda: c_int,
    ldb: c_int,
    ldc: c_int,

    tiles_n: c_int,
    tiles_m: c_int,

    batch_stride_a: c_int,
    batch_stride_b: c_int,
    batch_stride_c: c_int,

    swizzle_log: c_int,
    gemm_k_iterations_aligned: c_int,
}

#[allow(clippy::too_many_arguments)]
pub fn call_gemm(
    device: &Device,
@ -1291,10 +1251,10 @@ pub fn call_gemm(
    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
    let (a_trans, lda) = if lhs_m1 == 1 && lhs_m2 == k {
        (false, k as c_int)
    let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
        false
    } else if lhs_m1 == m && lhs_m2 == 1 {
        (true, n as c_int)
        true
    } else {
        return Err(MetalKernelError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
@ -1302,10 +1262,10 @@ pub fn call_gemm(
            mnk: (m, n, k),
        })?;
    };
    let (b_trans, ldb) = if rhs_m1 == 1 && rhs_m2 == n {
        (false, n as c_int)
    let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
        false
    } else if rhs_m1 == k && rhs_m2 == 1 {
        (true, k as c_int)
        true
    } else {
        return Err(MetalKernelError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
@ -1313,195 +1273,119 @@ pub fn call_gemm(
            mnk: (m, n, k),
        })?;
    };
    // let d_trans = false;
    // let alpha = 1.0f32;
    // let beta = 0.0f32;
    // let batched = b > 1;
    // let fused_activation = false;
    // let fused_bias = false;
    // let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
    //     let m_simd = 8;
    //     let n_simd = 8;
    //     let k_simd = 64;
    //     let m_splits = 1;
    //     let n_splits = 1;
    //     (m_simd, n_simd, k_simd, m_splits, n_splits)
    // } else {
    //     let m_simd = 40;
    //     let n_simd = 40;
    //     let k_simd = 32;
    //     let m_splits = 1;
    //     let n_splits = 1;
    //     (m_simd, n_simd, k_simd, m_splits, n_splits)
    // };
    // let constants = Some(ConstantValues::new(vec![
    //     (0, Value::USize(m)),
    //     (1, Value::USize(n)),
    //     (2, Value::USize(k)),
    //     (10, Value::Bool(a_trans)),
    //     (11, Value::Bool(b_trans)),
    //     (13, Value::Bool(d_trans)),
    //     (20, Value::F32(alpha)),
    //     (21, Value::F32(beta)),
    //     (100, Value::Bool(batched)),
    //     (101, Value::Bool(fused_activation)),
    //     // Garbage
    //     (102, Value::Bool(false)),
    //     (103, Value::Bool(false)),
    //     (113, Value::Bool(false)),
    //     (50_000, Value::Bool(false)),
    //     // End garbage
    //     (200, Value::U16(m_simd)),
    //     (201, Value::U16(n_simd)),
    //     (202, Value::U16(k_simd)),
    //     (210, Value::U16(m_splits)),
    //     (211, Value::U16(n_splits)),
    //     (50_001, Value::Bool(fused_bias)),
    // ]));
    let a_trans_name = if a_trans { "t" } else { "n" };
    let b_trans_name = if b_trans { "t" } else { "n" };
    let (iname, oname) = match name {
        "sgemm" => ("float32", "float32"),
        "hgemm" => ("float16", "float16"),
        "bgemm" => ("bfloat16", "bfloat16"),
    let d_trans = false;
    let alpha = 1.0f32;
    let beta = 0.0f32;
    let batched = b > 1;
    let fused_activation = false;
    let fused_bias = false;
    let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
        let m_simd = 8;
        let n_simd = 8;
        let k_simd = 64;
        let m_splits = 1;
        let n_splits = 1;
        (m_simd, n_simd, k_simd, m_splits, n_splits)
    } else {
        let m_simd = 40;
        let n_simd = 40;
        let k_simd = 32;
        let m_splits = 1;
        let n_splits = 1;
        (m_simd, n_simd, k_simd, m_splits, n_splits)
    };
    let constants = Some(ConstantValues::new(vec![
        (0, Value::USize(m)),
        (1, Value::USize(n)),
        (2, Value::USize(k)),
        (10, Value::Bool(a_trans)),
        (11, Value::Bool(b_trans)),
        (13, Value::Bool(d_trans)),
        (20, Value::F32(alpha)),
        (21, Value::F32(beta)),
        (100, Value::Bool(batched)),
        (101, Value::Bool(fused_activation)),
        // Garbage
        (102, Value::Bool(false)),
        (103, Value::Bool(false)),
        (113, Value::Bool(false)),
        (50_000, Value::Bool(false)),
        // End garbage
        (200, Value::U16(m_simd)),
        (201, Value::U16(n_simd)),
        (202, Value::U16(k_simd)),
        (210, Value::U16(m_splits)),
        (211, Value::U16(n_splits)),
        (50_001, Value::Bool(fused_bias)),
    ]));
    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
    let m_group = m_simd * m_splits;
    let n_group = n_simd * n_splits;

    let a_block_length = m_group * k_simd;
    let b_block_length = k_simd * n_group;

    let mut block_elements = a_block_length + b_block_length;
    if (m % 8 != 0) && (n % 8 != 0) {
        let c_block_length = m_group * n_group;
        block_elements = std::cmp::max(c_block_length, block_elements)
    }
    if fused_bias {
        if d_trans {
            block_elements = std::cmp::max(block_elements, m_group);
        } else {
            block_elements = std::cmp::max(block_elements, n_group);
        }
    }
    let bytes = match name {
        "sgemm" => 4,
        "hgemm" => 2,
        other => {
            return Err(MetalKernelError::LoadLibraryError(format!(
                "{other} is not a valid kernel for gemm"
            )))
            )));
        }
    };
    let mut bm = 32;
    let mut bn = 32;
    let mut bk = 16;
    let wm = 2;
    let wn = 2;
    if b * m * n >= 1 << 20 {
        if !a_trans && b_trans {
            bm = 64;
            bn = if oname == "float32" { 64 } else { 32 };
            bk = if oname == "float32" { 16 } else { 32 };
        } else {
            bm = 64;
            bn = 64;
        }
    }
    let mnaligned = if m % bm == 0 && n % bn == 0 {
        "taligned"
    } else {
        "naligned"
    };
    let kaligned = if k % bk == 0 { "taligned" } else { "naligned" };
    // let bytes = match &name[..] {
    //     "sgemm" => 4,
    //     "hgemm" => 2,
    //     other => {
    //         return Err(MetalKernelError::LoadLibraryError(format!(
    //             "{other} is not a valid kernel for gemm"
    //         )));
    //     }
    // };
    let name = format!("steel_gemm_{a_trans_name}{b_trans_name}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mnaligned}_K_{kaligned}");
    let name = string_to_static_str(name);
    let pipeline = kernels.load_pipeline(device, Source::Gemm, name)?;
    // let m_group = m_simd * m_splits;
    // let n_group = n_simd * n_splits;

    // let a_block_length = m_group * k_simd;
    // let b_block_length = k_simd * n_group;

    // let mut block_elements = a_block_length + b_block_length;
    // if (m % 8 != 0) && (n % 8 != 0) {
    //     let c_block_length = m_group * n_group;
    //     block_elements = std::cmp::max(c_block_length, block_elements)
    // }
    // if fused_bias {
    //     if d_trans {
    //         block_elements = std::cmp::max(block_elements, m_group);
    //     } else {
    //         block_elements = std::cmp::max(block_elements, n_group);
    //     }
    // }
    // let block_bytes = block_elements * bytes;
    let block_bytes = block_elements * bytes;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);
    // encoder.set_threadgroup_memory_length(0, block_bytes.into());

    let batch_stride_a: i32 = if lhs_stride.len() > 2 {
        lhs_stride[lhs_stride.len() - 3] as i32
    } else {
        0
    };
    let batch_stride_b: i32 = if rhs_stride.len() > 2 {
        rhs_stride[rhs_stride.len() - 3] as i32
    } else {
        0
    };
    let batch_stride_c = (m * n) as i32;

    let swizzle_log = 0;
    let tiles_n = ((n + bn - 1) / bn) as c_int;
    let tiles_m = ((m + bm - 1) / bm) as c_int;

    let params = GEMMParams {
        m: m as c_int,
        n: n as c_int,
        k: k as c_int,
        lda,
        ldb,
        ldc: n as c_int,
        tiles_m,
        tiles_n,
        batch_stride_a,
        batch_stride_b,
        batch_stride_c,
        swizzle_log,
        gemm_k_iterations_aligned: (k / bk) as c_int,
    };
    let params_buffer = device.new_buffer_with_data(
        &params as *const GEMMParams as *const c_void,
        core::mem::size_of::<GEMMParams>() as u64,
        MTLResourceOptions::StorageModeShared,
    );
    encoder.set_threadgroup_memory_length(0, block_bytes.into());
    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
    encoder.set_buffer(2, Some(output), 0);
    encoder.set_buffer(3, Some(&params_buffer), 0);
    // TODO Tensor D

    let grid_z = b;
    // if batched {
    //     let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
    //     let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
    //     let byte_stride_c = m * n * bytes as usize;
    //     // TODO byte_stride_d
    //     let byte_stride_d = 0;
    if batched {
        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
        let byte_stride_c = m * n * bytes as usize;
        // TODO byte_stride_d
        let byte_stride_d = 0;

    //     let buffer: Vec<u64> = vec![
    //         byte_stride_a as _,
    //         byte_stride_b as _,
    //         byte_stride_c as _,
    //         byte_stride_d as _,
    //     ];
    //     // encoder.set_bytes(
    //     //     10,
    //     //     (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
    //     //     buffer.as_ptr() as *const NSUInteger as *const c_void,
    //     // );
    // }
    let tile = 1 << swizzle_log;
    let tm = (tiles_m + tile - 1) / tile;
    let tn = tiles_n * tile;
        let buffer: Vec<u64> = vec![
            byte_stride_a as _,
            byte_stride_b as _,
            byte_stride_c as _,
            byte_stride_d as _,
        ];
        encoder.set_bytes(
            10,
            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
            buffer.as_ptr() as *const NSUInteger as *const c_void,
        );
    }

    let grid_size = MTLSize {
        width: tn as u64,
        height: tm as u64,
        width: divide(n, n_group.into()),
        height: divide(m, m_group.into()),
        depth: grid_z as NSUInteger,
    };
    let group_size = MTLSize {
        width: 32,
        height: wn,
        depth: wm,
        width: 32 * (m_splits as u64) * (n_splits as u64),
        height: 1,
        depth: 1,
    };
    encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
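
The removed steel path above selects a precompiled kernel purely by name: the string encodes transpose flags, dtypes, block sizes, warp tiling, and alignment. A minimal Rust sketch of that naming scheme (values illustrative, not taken from a real dispatch):

fn steel_gemm_name(
    a_trans: bool, b_trans: bool, dtype: &str,
    (bm, bn, bk): (usize, usize, usize), (wm, wn): (usize, usize),
    (m, n, k): (usize, usize, usize),
) -> String {
    let a = if a_trans { "t" } else { "n" };
    let b = if b_trans { "t" } else { "n" };
    // "taligned" is chosen when the tile sizes divide the problem exactly.
    let mn = if m % bm == 0 && n % bn == 0 { "taligned" } else { "naligned" };
    let ka = if k % bk == 0 { "taligned" } else { "naligned" };
    format!("steel_gemm_{a}{b}_{dtype}_{dtype}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mn}_K_{ka}")
}

fn main() {
    // e.g. a non-transposed f32 GEMM with 32x32x16 blocks on a 100x100x100 problem:
    assert_eq!(
        steel_gemm_name(false, false, "float32", (32, 32, 16), (2, 2), (100, 100, 100)),
        "steel_gemm_nn_float32_float32_bm32_bn32_bk16_wm2_wn2_MN_naligned_K_naligned"
    );
}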
@ -231,6 +231,25 @@ fn gelu_f32() {
    assert_eq!(approx(results, 3), expected);
}

#[test]
fn silu_f16() {
    let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
    let results = run(&v, unary::contiguous::silu::HALF);
    assert_eq!(approx_f16(results, 2), expected);
}

#[test]
fn silu_f32() {
    let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
    let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
    let results = run(&v, unary::contiguous::silu::FLOAT);
    assert_eq!(approx(results, 3), expected);
}
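
A standalone Rust sketch of the reference silu these expected values come from: silu(x) = x * sigmoid(x) = x / (1 + e^-x) (illustration only, separate from the test file):

fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

fn main() {
    for x in [-10f32, -1.0, 0.0, 1.0, 2.0, 3.0, 10.0, 20.0] {
        // Rounded to 3 decimals this reproduces -0.000, -0.269, 0.000,
        // 0.731, 1.762, 2.858, 10.000, 20.000 as asserted above.
        println!("silu({x}) = {:.3}", silu(x));
    }
}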

#[test]
fn binary_add_f32() {
    let left = vec![1.0f32, 2.0, 3.0];

@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){
    }
    return in;
}
template <typename T> METAL_FUNC T silu(T in){
    return in / (static_cast<T>(1) + exp(-in));
}

#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@ -108,6 +111,7 @@ UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(silu)
UNARY_OP(abs)
UNARY_OP(ceil)
UNARY_OP(floor)
@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(silu)
BFLOAT_UNARY_OP(abs)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
@ -30,7 +30,7 @@ impl super::Module for Activation {
            Self::Relu => xs.relu(),
            Self::Relu2 => xs.relu()?.sqr(),
            Self::Relu6 => xs.clamp(0f32, 6f32),
            Self::Silu => crate::ops::silu(xs),
            Self::Silu => xs.silu(),
            Self::Sigmoid => crate::ops::sigmoid(xs),
            Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
            Self::Swiglu => crate::ops::swiglu(xs),
@ -35,13 +35,12 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
}

pub fn silu(xs: &Tensor) -> Result<Tensor> {
    // TODO: Should we have a specialized op for this?
    xs / (xs.neg()?.exp()? + 1.0)?
    xs.silu()
}

pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, candle::D::Minus1)?;
    crate::ops::silu(&xs[0])? * &xs[1]
    &xs[0].silu()? * &xs[1]
}
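
swiglu halves the last dimension and gates one half with the silu of the other. A plain-Rust sketch of that gating on a flat feature vector (illustration only; not the candle API):

fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

fn swiglu(xs: &[f32]) -> Vec<f32> {
    // Split the feature vector in two along the last dimension and
    // compute silu(a) * b elementwise.
    let half = xs.len() / 2;
    let (a, b) = xs.split_at(half);
    a.iter().zip(b).map(|(&a, &b)| silu(a) * b).collect()
}

fn main() {
    // 4 features -> 2 outputs: silu(1)*3 and silu(2)*4.
    let out = swiglu(&[1.0, 2.0, 3.0, 4.0]);
    println!("{out:?}"); // approx [2.193, 7.046]
}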

pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {

@ -2,10 +2,16 @@
//!
//! See "A ConvNet for the 2020s" Liu et al. 2022
//! <https://arxiv.org/abs/2201.03545>
//! and
//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
//! <https://arxiv.org/abs/2301.00808>

//! Original code: https://github.com/facebookresearch/ConvNeXt/
//! Original code:
//! https://github.com/facebookresearch/ConvNeXt/
//! https://github.com/facebookresearch/ConvNeXt-V2/
//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py

use candle::shape::ShapeWithOneHole;
use candle::{Result, D};
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};

@ -13,31 +19,71 @@ use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
pub struct Config {
    blocks: [usize; 4],
    channels: [usize; 4],
    use_conv_mlp: bool,
}

impl Config {
    pub fn atto() -> Self {
        Self {
            blocks: [2, 2, 6, 2],
            channels: [40, 80, 160, 320],
            use_conv_mlp: true,
        }
    }

    pub fn femto() -> Self {
        Self {
            blocks: [2, 2, 6, 2],
            channels: [48, 96, 192, 384],
            use_conv_mlp: true,
        }
    }

    pub fn pico() -> Self {
        Self {
            blocks: [2, 2, 6, 2],
            channels: [64, 128, 256, 512],
            use_conv_mlp: true,
        }
    }

    pub fn nano() -> Self {
        Self {
            blocks: [2, 2, 8, 2],
            channels: [80, 160, 320, 640],
            use_conv_mlp: true,
        }
    }

    pub fn tiny() -> Self {
        Self {
            blocks: [3, 3, 9, 3],
            channels: [96, 192, 384, 768],
            use_conv_mlp: false,
        }
    }

    pub fn small() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [96, 192, 384, 768],
            use_conv_mlp: false,
        }
    }

    pub fn base() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [128, 256, 512, 1024],
            use_conv_mlp: false,
        }
    }

    pub fn large() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [192, 384, 768, 1536],
            use_conv_mlp: false,
        }
    }

@ -45,8 +91,68 @@ impl Config {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [256, 512, 1024, 2048],
            use_conv_mlp: false,
        }
    }

    pub fn huge() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [352, 704, 1408, 2816],
            use_conv_mlp: false,
        }
    }
}

// Layer norm for data in channels-last format.
fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let norm = layer_norm(dim, 1e-6, vb)?;

    Ok(Func::new(move |xs| xs.apply(&norm)))
}

// Layer norm for data in channels-first format.
fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let norm = layer_norm(dim, 1e-6, vb)?;

    Ok(Func::new(move |xs| {
        let xs = xs
            .permute((0, 2, 3, 1))?
            .apply(&norm)?
            .permute((0, 3, 1, 2))?;
        Ok(xs)
    }))
}

// Global response normalization layer
// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
    let (shape, spatial_dim, channel_dim) = if channels_last {
        ((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
    } else {
        ((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
    };

    let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
    let beta = vb.get(dim, "bias")?.reshape(&shape)?;

    Ok(Func::new(move |xs| {
        let residual = xs;
        let gx = xs
            .sqr()?
            .sum_keepdim(spatial_dim)?
            .mean_keepdim(spatial_dim)?
            .sqrt()?;

        let gxmean = gx.mean_keepdim(channel_dim)?;
        let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
        let xs = xs
            .broadcast_mul(&nx)?
            .broadcast_mul(&gamma)?
            .broadcast_add(&beta)?;

        xs + residual
    }))
}
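
Restated as math, the GRN closure above computes, per channel c over spatial positions (h, w) (note the mean_keepdim over the already-reduced spatial dims is a no-op, so gx is the spatial L2 norm):

\begin{aligned}
G(x)_c &= \Big(\textstyle\sum_{h,w} x_{h,w,c}^2\Big)^{1/2}, \\
N(x)_c &= \frac{G(x)_c}{\tfrac{1}{C}\sum_{c'} G(x)_{c'} + \epsilon}, \\
\mathrm{GRN}(x)_{h,w,c} &= \gamma_c \, x_{h,w,c} \, N(x)_c + \beta_c + x_{h,w,c}.
\end{aligned}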

// Initial downsampling via a patchify layer.
@ -56,16 +162,9 @@ fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
        ..Default::default()
    };
    let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
    let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
    Ok(Func::new(move |xs| {
        // The layer norm works with channels-last format.
        let xs = xs
            .apply(&patchify)?
            .permute((0, 2, 3, 1))?
            .apply(&norm)?
            .permute((0, 3, 1, 2))?;
        Ok(xs)
    }))
    let norm = layer_norm_cf(out_channels, vb.pp(1))?;

    Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
}

// Downsampling applied after the stages.
@ -74,31 +173,49 @@ fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
        stride: 2,
        ..Default::default()
    };
    let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
    let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
    let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
    Ok(Func::new(move |xs| {
        let xs = xs
            .permute((0, 2, 3, 1))?
            .apply(&norm)?
            .permute((0, 3, 1, 2))?
            .apply(&conv)?;
        Ok(xs)
    }))

    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
}

// MLP equivalent of pointwise convolutions.
// MLP block from the original paper with optional GRN layer (v2 models).
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
    let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
    let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));

    Ok(Func::new(move |xs| {
        let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
        if let Ok(g) = &grn {
            xs = xs.apply(g)?;
        }
        xs = xs.apply(&fc2)?;
        Ok(xs)
    }))
}

// A block consisting of a depthwise convolution, an MLP and layer scaling.
fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        ..Default::default()
    };
    let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
    let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;

    let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
    Ok(Func::new(move |xs| {
        let mut xs = xs.apply(&fc1)?.gelu_erf()?;
        if let Ok(g) = &grn {
            xs = xs.apply(g)?;
        }
        xs = xs.apply(&fc2)?;
        Ok(xs)
    }))
}

// A block consisting of a depthwise convolution, an MLP and layer scaling (v1 models only).
fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        groups: dim,
        padding: 3,
@ -106,20 +223,36 @@ fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    };

    let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
    let gamma = vb.get(dim, "gamma");

    let gamma = vb.get(dim, "gamma")?;
    let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
    let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
    let (mlp, norm) = if use_conv_mlp {
        (
            convnext_conv_mlp(dim, vb.pp("mlp"))?,
            layer_norm_cf(dim, vb.pp("norm"))?,
        )
    } else {
        (
            convnext_mlp(dim, vb.pp("mlp"))?,
            layer_norm_cl(dim, vb.pp("norm"))?,
        )
    };

    Ok(Func::new(move |xs| {
        let residual = xs;
        let xs = xs
            .apply(&conv_dw)?
            .permute((0, 2, 3, 1))?
        let mut xs = xs.apply(&conv_dw)?;

        xs = if use_conv_mlp {
            xs.apply(&norm)?.apply(&mlp)?
        } else {
            xs.permute((0, 2, 3, 1))?
                .apply(&norm)?
                .apply(&mlp)?
            .broadcast_mul(&gamma)?
            .permute((0, 3, 1, 2))?;
                .permute((0, 3, 1, 2))?
        };

        if let Ok(g) = &gamma {
            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
        };

        xs + residual
    }))
@ -137,7 +270,11 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
    }

    for block_idx in 0..nblocks {
        blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
        blocks.push(convnext_block(
            dim,
            cfg.use_conv_mlp,
            vb.pp(format!("blocks.{block_idx}")),
        )?);
    }

    Ok(Func::new(move |xs| {
@ -149,8 +286,9 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
    }))
}

// Classification head.
fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
    let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
    let linear = linear(outputs, nclasses, vb.pp("fc"))?;
    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
}
@ -1,13 +1,12 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

pub const MAX_SEQ_LEN: usize = 4096;

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
@ -34,6 +34,7 @@ pub mod quantized_t5;
pub mod qwen2;
pub mod repvgg;
pub mod resnet;
pub mod rwkv_v5;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
@ -41,6 +42,7 @@ pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
pub mod vocos;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
candle-transformers/src/models/rwkv_v5.rs (new file, 409 lines)
@ -0,0 +1,409 @@
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::collections::{HashMap, HashSet};

fn default_num_attention_heads() -> usize {
    64
}

// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub attention_hidden_size: usize,
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    pub head_size: usize,
    pub intermediate_size: Option<usize>,
    pub layer_norm_epsilon: f64,
    pub rescale_every: usize,
}

struct StatePerLayer {
    extract_key_value: Tensor,
    linear_attention: Tensor,
    feed_forward: Tensor,
}

pub struct State {
    per_layer: Vec<StatePerLayer>,
    pos: usize,
}

impl State {
    pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
        let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
        // Certainly a weird convention but taken from modeling_rwkv5.py
        let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
        for _layer_idx in 0..cfg.num_hidden_layers {
            let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
            let linear_attention = Tensor::zeros(
                (
                    batch_size,
                    num_attention_heads,
                    cfg.hidden_size / num_attention_heads,
                    cfg.hidden_size / num_attention_heads,
                ),
                DType::F32,
                dev,
            )?;
            let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
            per_layer.push(StatePerLayer {
                extract_key_value,
                linear_attention,
                feed_forward,
            });
        }
        Ok(Self { per_layer, pos: 0 })
    }
}

#[derive(Debug, Clone)]
struct SelfAttention {
    key: Linear,
    receptance: Linear,
    value: Linear,
    gate: Linear,
    output: Linear,
    ln_x: candle_nn::GroupNorm,
    time_mix_key: Tensor,
    time_mix_value: Tensor,
    time_mix_receptance: Tensor,
    time_decay: Tensor,
    time_faaaa: Tensor,
    time_mix_gate: Tensor,
    layer_id: usize,
    n_attn_heads: usize,
}

impl SelfAttention {
    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let attn_hidden_size = cfg.attention_hidden_size;
        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
        let ln_x = candle_nn::group_norm(
            hidden_size / cfg.head_size,
            hidden_size,
            1e-5,
            vb.pp("ln_x"),
        )?;
        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
        let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
        let n_attn_heads = cfg.hidden_size / cfg.head_size;
        let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
        let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
        let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
        Ok(Self {
            key,
            value,
            receptance,
            gate,
            output,
            ln_x,
            time_mix_key,
            time_mix_value,
            time_mix_receptance,
            time_decay,
            time_faaaa,
            time_mix_gate,
            layer_id,
            n_attn_heads,
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let h = self.time_decay.dim(0)?;
        let (b, t, s) = xs.dims3()?;
        let s = s / h;
        let (receptance, key, value, gate) = {
            // extract key-value
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?
            } else {
                shifted
            };
            let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
            let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
            let receptance = ((xs * &self.time_mix_receptance)?
                + &shifted * (1.0 - &self.time_mix_receptance)?)?;
            let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;

            let key = self.key.forward(&key)?;
            let value = self.value.forward(&value)?;
            let receptance = self.receptance.forward(&receptance)?;
            let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
            (receptance, key, value, gate)
        };
        // linear attention
        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;

        let time_decay = self
            .time_decay
            .exp()?
            .neg()?
            .exp()?
            .reshape(((), 1, 1))?
            .reshape((self.n_attn_heads, (), 1))?;
        let time_faaaa =
            self.time_faaaa
                .reshape(((), 1, 1))?
                .reshape((self.n_attn_heads, (), 1))?;

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        for t_ in 0..t {
            let rt = receptance.i((.., .., t_..t_ + 1))?;
            let kt = key.i((.., .., .., t_..t_ + 1))?;
            let vt = value.i((.., .., t_..t_ + 1))?;
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
            state_ = (&at + time_decay.broadcast_mul(&state_))?;
            out.push(out_)
        }
        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
        let out = (out * gate)?.apply(&self.output)?;
        state.per_layer[self.layer_id].linear_attention = state_;
        Ok(out)
    }
}
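
The time loop above realizes the RWKV-5 linear-attention recurrence per head, with outer products k_t v_t^T, decay w = exp(-exp(time_decay)), and bonus u = time_faaaa (a restatement of what the code computes, not from the diff itself):

\begin{aligned}
o_t &= r_t \left( u \odot (k_t v_t^\top) + S_{t-1} \right), \\
S_t &= k_t v_t^\top + w \odot S_{t-1},
\end{aligned}

so the state S carries the whole history at O(1) cost per token instead of attending over all previous positions.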

#[derive(Debug, Clone)]
struct FeedForward {
    time_mix_key: Tensor,
    time_mix_receptance: Tensor,
    key: Linear,
    receptance: Linear,
    value: Linear,
    layer_id: usize,
}

impl FeedForward {
    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let int_size = cfg
            .intermediate_size
            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
        Ok(Self {
            key,
            receptance,
            value,
            time_mix_key,
            time_mix_receptance,
            layer_id,
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let shifted = &state.per_layer[self.layer_id].feed_forward;
        let key = (xs.broadcast_mul(&self.time_mix_key)?
            + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
        let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
            + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
        let key = key.apply(&self.key)?.relu()?.sqr()?;
        let value = key.apply(&self.value)?;
        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
        let xs = (receptance * value)?;
        Ok(xs)
    }
}
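
In math, the channel mixing above combines a token-shift interpolation mix(a, b) = mu * a + (1 - mu) * b, a squared-ReLU key, and a sigmoid receptance gate (again a restatement, not part of the diff):

\begin{aligned}
k_t &= \left(\mathrm{ReLU}\!\big(W_k\,\mathrm{mix}(x_t, x_{t-1})\big)\right)^2, \\
o_t &= \sigma\!\big(W_r\,\mathrm{mix}'(x_t, x_{t-1})\big) \odot W_v k_t.
\end{aligned}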

#[derive(Debug, Clone)]
struct Block {
    pre_ln: Option<LayerNorm>,
    ln1: LayerNorm,
    ln2: LayerNorm,
    attention: SelfAttention,
    feed_forward: FeedForward,
}

impl Block {
    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
        let pre_ln = if layer_id == 0 {
            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
            Some(ln)
        } else {
            None
        };
        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
        Ok(Self {
            pre_ln,
            ln1,
            ln2,
            attention,
            feed_forward,
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let xs = match self.pre_ln.as_ref() {
            None => xs.clone(),
            Some(pre_ln) => xs.apply(pre_ln)?,
        };
        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
        let xs = (xs + attention)?;
        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
        let xs = (xs + feed_forward)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embedding,
    blocks: Vec<Block>,
    ln_out: LayerNorm,
    head: Linear,
    rescale_every: usize,
    layers_are_rescaled: bool,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("rwkv");
        let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_b = vb_m.pp("blocks");
        for block_index in 0..cfg.num_hidden_layers {
            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
            blocks.push(block)
        }
        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            blocks,
            ln_out,
            head,
            rescale_every: cfg.rescale_every,
            layers_are_rescaled: false, // This seems to only happen for the f16/bf16 dtypes.
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let (_b_size, _seq_len) = xs.dims2()?;
        let mut xs = xs.apply(&self.embeddings)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            xs = block.forward(&xs, state)?;
            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
                xs = (xs / 2.)?
            }
        }
        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
        state.pos += 1;
        Ok(xs)
    }
}

type Bytes = Vec<u8>;

// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14
pub struct Tokenizer {
    table: Vec<Vec<Vec<Bytes>>>,
    good: Vec<HashSet<u8>>,
    idx2token: HashMap<u32, Vec<u8>>,
    token2idx: HashMap<Vec<u8>, u32>,
}

impl Tokenizer {
    pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
        let file = std::fs::File::open(p)?;
        let token2idx: HashMap<String, u32> =
            serde_json::from_reader(file).map_err(candle::Error::wrap)?;
        let token2idx = token2idx
            .into_iter()
            .map(|(key, value)| (key.into_bytes(), value))
            .collect::<HashMap<_, _>>();
        let idx2token = token2idx
            .iter()
            .map(|(key, value)| (*value, key.to_vec()))
            .collect::<HashMap<_, _>>();

        let max_idx = token2idx.values().copied().max().unwrap_or(0);

        let mut table = vec![vec![vec![]; 256]; 256];
        let mut good = vec![HashSet::new(); 256];
        for idx in (0..(1 + max_idx)).rev() {
            let s = match idx2token.get(&idx) {
                None => continue,
                Some(s) => s,
            };
            if s.len() >= 2 {
                let (s0, s1) = (s[0], s[1]);
                table[s0 as usize][s1 as usize].push(s.to_vec());
                good[s0 as usize].insert(s1);
            }
        }
        Ok(Self {
            table,
            good,
            idx2token,
            token2idx,
        })
    }

    pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
        let mut v = Vec::new();
        for token_id in tokens.iter() {
            if let Some(token) = self.idx2token.get(token_id) {
                v.extend_from_slice(token.as_slice())
            }
        }
        v
    }

    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
        let bytes = self.decode_bytes(tokens);
        String::from_utf8(bytes).map_err(candle::Error::wrap)
    }

    pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
        let mut tokens = Vec::new();
        let mut i = 0;
        while i < bytes.len() {
            let mut s = vec![bytes[i]];
            if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
                let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
                for table_elem in table.iter() {
                    if bytes[i..].starts_with(table_elem) {
                        s = table_elem.to_vec();
                        break;
                    }
                }
            }
            i += s.len();
            let token = match self.token2idx.get(&s) {
                None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
                Some(token) => *token,
            };
            tokens.push(token)
        }
        Ok(tokens)
    }

    pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
        self.encode_bytes(str.as_bytes())
    }
}
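
A hypothetical usage sketch of the greedy byte-level tokenizer above; the vocab file name is illustrative and not shipped with this diff, and the round trip assumes every byte sequence is covered by the vocab:

fn demo() -> Result<()> {
    let tokenizer = Tokenizer::new("rwkv_vocab_v20230424.json")?;
    // encode_bytes walks the input left to right, trying the multi-byte
    // candidates bucketed by the first two bytes (longest match wins via
    // the descending-id table order) before falling back to a single byte.
    let ids = tokenizer.encode("Hello world")?;
    let text = tokenizer.decode(&ids)?;
    assert_eq!(text, "Hello world");
    Ok(())
}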

candle-transformers/src/models/vocos.rs (new file, 156 lines)
@ -0,0 +1,156 @@
#![allow(unused)]
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};

pub struct AdaLayerNorm {
    eps: f64,
    dim: usize,
    scale: Embedding,
    shift: Embedding,
}

fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let x = {
        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        x.broadcast_sub(&mean_x)?
    };
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)
}
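
The helper above is the scale/shift-free layer norm over the last dimension; AdaLayerNorm then applies the conditional affine parameters looked up from the embeddings. In math:

\mathrm{LN}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
\mu = \tfrac{1}{d}\sum_i x_i, \quad \sigma^2 = \tfrac{1}{d}\sum_i (x_i - \mu)^2 .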

impl AdaLayerNorm {
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        eps: f64,
        vb: VarBuilder,
    ) -> Result<Self> {
        let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
        let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
        Ok(Self {
            eps,
            dim: embedding_dim,
            scale,
            shift,
        })
    }

    pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
        let scale = self.scale.forward(cond_embedding_id)?;
        let shift = self.shift.forward(cond_embedding_id)?;
        let xs = layer_norm(xs, self.eps)?;
        xs * scale + shift
    }
}

pub struct ConvNeXtBlock {
    dwconv: Conv1d,
    pwconv1: Linear,
    pwconv2: Linear,
    gamma: Option<Tensor>,
}

impl ConvNeXtBlock {
    pub fn new(
        dim: usize,
        intermediate_dim: usize,
        layer_scale_init_value: f64,
        adanorm_num_embeddings: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let dwconv = {
            let cfg = Conv1dConfig {
                padding: 3,
                groups: dim,
                ..Default::default()
            };
            conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
        };
        let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
        let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
        let gamma = if layer_scale_init_value > 0. {
            Some(vb.get(dim, "gamma")?)
        } else {
            None
        };
        Ok(Self {
            dwconv,
            pwconv1,
            pwconv2,
            gamma,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
        // TODO: norm
        let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
        let xs = match self.gamma.as_ref() {
            Some(gamma) => (gamma * xs)?,
            None => xs,
        };
        xs.transpose(1, 2)? + residual
    }
}

struct VocosBackbone {
    embed: Conv1d,
    convnext: Vec<ConvNeXtBlock>,
    final_layer_norm: candle_nn::LayerNorm,
}

impl VocosBackbone {
    pub fn new(
        input_channels: usize,
        dim: usize,
        intermediate_dim: usize,
        num_layers: usize,
        layer_scale_init_value: f64,
        adanorm_num_embeddings: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let embed = {
            let cfg = Conv1dConfig {
                padding: 3,
                ..Default::default()
            };
            conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
        };
        let mut convnext = Vec::with_capacity(num_layers);
        let vb_c = vb.pp("convnext");
        for i in 0..num_layers {
            let block = ConvNeXtBlock::new(
                dim,
                intermediate_dim,
                layer_scale_init_value,
                adanorm_num_embeddings,
                vb_c.pp(i),
            )?;
            convnext.push(block)
        }
        let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
        Ok(Self {
            embed,
            convnext,
            final_layer_norm,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.embed)?;
        // TODO: norm
        let mut xs = xs.transpose(1, 2)?;
        for conv_block in self.convnext.iter() {
            xs = conv_block.forward(&xs)?
        }
        xs.apply(&self.final_layer_norm)
    }
}