Compare commits


4 Commits

SHA1 Message Date
5ac3302fac Prebuild all our kernels. 2024-03-18 16:39:38 +01:00
41416d2376 Expose more conv1d functions/structs. (#1726) 2024-02-17 18:50:55 +01:00
5ebcfeaf0f Make the r, k, v tensors contiguous. (#1719) 2024-02-16 09:17:35 +01:00
7c7400fb63 Use the tokenizer-output-stream in the llama example. (#1715)
* Use the tokenizer-output-stream in the llama example.

* Also use tokenizer-output-stream for llama2-c.
2024-02-15 16:47:33 +01:00
14 changed files with 109 additions and 213 deletions

View File

@@ -588,6 +588,7 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::F16) => "cast_u32_f16",
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U8, DType::U32) => "cast_u8_u32",
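For context, this table maps (source, target) dtype pairs to Metal kernel names; user code reaches it through Tensor::to_dtype. A minimal sketch of exercising the new u32 -> f16 path (assumes a metal-enabled build of candle; not part of the diff):

```rust
use candle::{DType, Device, Result, Tensor};

fn cast_on_metal() -> Result<Tensor> {
    // Requires the `metal` feature; ordinal 0 selects the default GPU.
    let device = Device::new_metal(0)?;
    let xs = Tensor::new(&[1u32, 2, 3], &device)?;
    // Dispatches the `cast_u32_f16` kernel registered above.
    xs.to_dtype(DType::F16)
}
```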

View File

@@ -57,7 +57,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
#[arg(long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
@@ -143,7 +143,6 @@ fn main() -> Result<()> {
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
@@ -157,6 +156,7 @@
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop");
print!("{prompt}");
@@ -190,18 +190,16 @@ fn main() -> Result<()> {
token_generated += 1;
tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
if Some(next_token) == eos_token_id {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
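The change above swaps the manual id_to_token heuristics for TokenOutputStream, which buffers token ids until they decode to a complete UTF-8 chunk (llama2-c below gets the same treatment). A condensed sketch of the pattern, with `sample_next` as a hypothetical stand-in for the model forward pass plus LogitsProcessor, and anyhow for error plumbing as in the example:

```rust
use candle_examples::token_output_stream::TokenOutputStream;
use std::io::Write;

fn stream_decode(
    tokenizer: tokenizers::Tokenizer,
    mut sample_next: impl FnMut() -> anyhow::Result<u32>,
    eos_token_id: u32,
    sample_len: usize,
) -> anyhow::Result<()> {
    let mut stream = TokenOutputStream::new(tokenizer);
    for _ in 0..sample_len {
        let next_token = sample_next()?;
        if next_token == eos_token_id {
            break;
        }
        // next_token only yields text once the buffered ids decode to a
        // valid UTF-8 piece, so multi-byte glyphs are never split.
        if let Some(t) = stream.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    // Flush whatever is still buffered after the loop ends.
    if let Some(rest) = stream.decode_rest()? {
        print!("{rest}");
    }
    Ok(())
}
```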

View File

@@ -328,6 +328,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let start_gen = std::time::Instant::now();
for index in 0.. {
@@ -353,16 +354,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",

View File

@@ -152,7 +152,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]

View File

@@ -143,7 +143,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]

candle-metal-kernels/.gitignore (new file)
View File

@@ -0,0 +1,2 @@
src/compiled/

View File

@@ -0,0 +1,45 @@
use std::path::Path;
use std::process::Command;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let files: std::fs::ReadDir = std::fs::read_dir("src/").unwrap();
for file in files {
let file = file?;
let path = file.path();
if let Some(extension) = path.extension() {
if extension == "metal" {
build_kernel(&path)?;
}
println!("cargo:warning=output {:?}", path.file_stem());
}
}
Ok(())
}
fn build_kernel(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
let stem = path
.file_stem()
.expect("expect real filename")
.to_str()
.expect("expect real stem");
Command::new("xcrun")
.args([
"metal",
"-c",
path.as_os_str().to_str().expect("Expect a real filename"),
"-I",
"src/",
"-o",
&format!("src/compiled/{stem}.air"),
])
.output()?;
Command::new("xcrun")
.args([
"metallib",
&format!("src/compiled/{stem}.air"),
"-o",
&format!("src/compiled/{stem}.metallib"),
])
.output()?;
Ok(())
}
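One caveat in the build script above: Command::output() captures stdout/stderr but does not fail on a non-zero exit, so a broken kernel would only surface later as a missing or stale .metallib. A hedged sketch of a stricter variant (run_checked is a hypothetical helper, not part of this commit):

```rust
use std::process::Command;

// Hypothetical helper: run a compiler invocation and fail the build,
// surfacing the captured stderr, if it exits non-zero.
fn run_checked(cmd: &mut Command) -> Result<(), Box<dyn std::error::Error>> {
    let output = cmd.output()?;
    if !output.status.success() {
        return Err(format!(
            "{cmd:?} failed: {}",
            String::from_utf8_lossy(&output.stderr)
        )
        .into());
    }
    Ok(())
}
```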

View File

@@ -73,6 +73,7 @@ kernel void FN_NAME_STRIDED( \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
@@ -95,4 +96,4 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
#endif

View File

@@ -1,22 +1,22 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib");
const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib");
const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");
const BINARY: &[u8] = include_bytes!("compiled/binary.metallib");
const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib");
const CAST: &[u8] = include_bytes!("compiled/cast.metallib");
const CONV: &[u8] = include_bytes!("compiled/conv.metallib");
const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib");
const RANDOM: &[u8] = include_bytes!("compiled/random.metallib");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib");
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@@ -235,7 +235,7 @@ impl Kernels {
}
}
fn get_library_source(&self, source: Source) -> &'static str {
fn get_library_source(&self, source: Source) -> &'static [u8] {
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
@@ -247,7 +247,7 @@ impl Kernels {
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"),
Source::Mfa => MFA,
}
}
@@ -262,22 +262,12 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let lib = match source {
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
}
source => {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
}
};
let source_data = self.get_library_source(source);
let lib = device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?;
libraries.insert(source, lib.clone());
Ok(lib)
}
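The upshot of this hunk: every kernel constant becomes precompiled .metallib bytes, loading goes through metal-rs's new_library_with_data instead of new_library_with_source, no Metal compiler runs at startup, and the Mfa special case disappears since all libraries now take the same path. A minimal loading sketch (the path assumes the build.rs shown earlier has produced the file):

```rust
use metal::{Device, Library};

// Bytes are baked into the binary at compile time, mirroring the
// constants in the diff above.
const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");

fn load_unary(device: &Device) -> Result<Library, String> {
    // Parses an already-serialized metallib; per the error message in
    // the diff, this path needs macOS 13.0 or higher.
    device.new_library_with_data(UNARY)
}
```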

View File

@@ -302,6 +302,22 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg))
}
pub fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv1dConfig,
vb: crate::VarBuilder,
) -> Result<Conv1d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
init_ws,
)?;
Ok(Conv1d::new(ws, None, cfg))
}
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,
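A quick usage sketch for the newly exposed conv1d_no_bias (shapes and names are illustrative; a VarMap-backed VarBuilder initializes the weight on first access):

```rust
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{conv1d_no_bias, Conv1dConfig, VarBuilder, VarMap};

fn demo() -> Result<Tensor> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // 16 in-channels -> 32 out-channels, kernel size 3, default config.
    let conv = conv1d_no_bias(16, 32, 3, Conv1dConfig::default(), vb.pp("conv"))?;
    // (batch, channels, length) layout, as for candle's other conv1d ops.
    let xs = Tensor::zeros((1, 16, 100), DType::F32, &dev)?;
    conv.forward(&xs) // no bias term is added in the forward pass
}
```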

View File

@@ -19,8 +19,9 @@ pub mod var_map;
pub use activation::{prelu, Activation, PReLU};
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{
conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};

View File

@@ -42,7 +42,6 @@ pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
pub mod vocos;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;

View File

@@ -165,9 +165,9 @@ impl SelfAttention {
let mut out: Vec<Tensor> = Vec::with_capacity(t);
for t_ in 0..t {
//
let rt = receptance.i((.., .., t_..t_ + 1))?;
let kt = key.i((.., .., .., t_..t_ + 1))?;
let vt = value.i((.., .., t_..t_ + 1))?;
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
let at = kt.matmul(&vt)?;
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
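The fix above matters because indexing with .i(..) along a middle dimension returns a strided view, and the Metal matmul path at the time mishandled non-contiguous operands; .contiguous() materializes the slice first. The general pattern, as a small sketch:

```rust
use candle::{IndexOp, Result, Tensor};

fn narrow_then_matmul(x: &Tensor, w: &Tensor, t: usize) -> Result<Tensor> {
    // `.i` on a middle dim yields a non-contiguous view; materialize it
    // before matmul so every backend accepts the operand layout.
    let xt = x.i((.., .., t..t + 1))?.contiguous()?;
    xt.matmul(w)
}
```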

View File

@@ -1,156 +0,0 @@
#![allow(unused)]
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};
pub struct AdaLayerNorm {
eps: f64,
dim: usize,
scale: Embedding,
shift: Embedding,
}
fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let x = {
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
x.broadcast_sub(&mean_x)?
};
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
x_normed.to_dtype(x_dtype)
}
impl AdaLayerNorm {
pub fn new(
num_embeddings: usize,
embedding_dim: usize,
eps: f64,
vb: VarBuilder,
) -> Result<Self> {
let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
Ok(Self {
eps,
dim: embedding_dim,
scale,
shift,
})
}
pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
let scale = self.scale.forward(cond_embedding_id)?;
let shift = self.shift.forward(cond_embedding_id)?;
let xs = layer_norm(xs, self.eps)?;
xs * scale + shift
}
}
pub struct ConvNeXtBlock {
dwconv: Conv1d,
pwconv1: Linear,
pwconv2: Linear,
gamma: Option<Tensor>,
}
impl ConvNeXtBlock {
pub fn new(
dim: usize,
intermediate_dim: usize,
layer_scale_init_value: f64,
adanorm_num_embeddings: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let dwconv = {
let cfg = Conv1dConfig {
padding: 3,
groups: dim,
..Default::default()
};
conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
};
let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
let gamma = if layer_scale_init_value > 0. {
Some(vb.get(dim, "gamma")?)
} else {
None
};
Ok(Self {
dwconv,
pwconv1,
pwconv2,
gamma,
})
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
// TODO: norm
let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
let xs = match self.gamma.as_ref() {
Some(gamma) => (gamma * xs)?,
None => xs,
};
xs.transpose(1, 2)? + residual
}
}
struct VocosBackbone {
embed: Conv1d,
convnext: Vec<ConvNeXtBlock>,
final_layer_norm: candle_nn::LayerNorm,
}
impl VocosBackbone {
pub fn new(
input_channels: usize,
dim: usize,
intermediate_dim: usize,
num_layers: usize,
layer_scale_init_value: f64,
adanorm_num_embeddings: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let embed = {
let cfg = Conv1dConfig {
padding: 3,
..Default::default()
};
conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
};
let mut convnext = Vec::with_capacity(num_layers);
let vb_c = vb.pp("convnext");
for i in 0..num_layers {
let block = ConvNeXtBlock::new(
dim,
intermediate_dim,
layer_scale_init_value,
adanorm_num_embeddings,
vb_c.pp(i),
)?;
convnext.push(block);
}
let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
Ok(Self {
embed,
convnext,
final_layer_norm,
})
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embed)?;
// TODO: norm
let mut xs = xs.transpose(1, 2)?;
for conv_block in self.convnext.iter() {
xs = conv_block.forward(&xs)?
}
xs.apply(&self.final_layer_norm)
}
}