Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Compare commits: vocos...precompile (4 commits)
Commits: 5ac3302fac, 41416d2376, 5ebcfeaf0f, 7c7400fb63
candle-core/src/metal_backend.rs

```diff
@@ -588,6 +588,7 @@ impl BackendStorage for MetalStorage {
            (DType::U32, DType::F32) => "cast_u32_f32",
            (DType::U32, DType::U8) => "cast_u32_u8",
            (DType::U32, DType::I64) => "cast_u32_i64",
            (DType::U32, DType::F16) => "cast_u32_f16",
            (DType::U32, DType::BF16) => "cast_u32_bf16",

            (DType::U8, DType::U32) => "cast_u8_u32",
```
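These match arms map dtype pairs to Metal kernel names; user code reaches them through `Tensor::to_dtype`. A minimal sketch of what exercises this path (assuming a Metal-capable machine; `Device::new_metal` and `Tensor::arange` are existing candle APIs):

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // On a Metal device, to_dtype dispatches to the cast_* kernels above.
    let dev = Device::new_metal(0)?;
    let t = Tensor::arange(0u32, 8u32, &dev)?; // DType::U32
    let t_f16 = t.to_dtype(DType::F16)?;       // selects "cast_u32_f16"
    let t_bf16 = t.to_dtype(DType::BF16)?;     // selects "cast_u32_bf16"
    println!("{t_f16}\n{t_bf16}");
    Ok(())
}
```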
candle-examples/examples/llama/main.rs

```diff
@@ -57,7 +57,7 @@ struct Args {
    seed: u64,

    /// The length of the sample to generate (in tokens).
-   #[arg(long, default_value_t = 100)]
+   #[arg(long, default_value_t = 10000)]
    sample_len: usize,

    /// Disable the key-value cache.
@@ -143,7 +143,6 @@ fn main() -> Result<()> {
        }
        Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
    };
    println!("building the model");
    let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
@@ -157,6 +156,7 @@ fn main() -> Result<()> {
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
+   let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

    println!("starting the inference loop");
    print!("{prompt}");
@@ -190,18 +190,16 @@ fn main() -> Result<()> {
        token_generated += 1;
        tokens.push(next_token);

-       // Extracting the last token as a string is complicated, here we just apply some simple
-       // heuristics as it seems to work well enough for this example. See the following for more
-       // details:
-       // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-       if let Some(text) = tokenizer.id_to_token(next_token) {
-           let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-           print!("{text}");
-           std::io::stdout().flush()?;
-       }
        if Some(next_token) == eos_token_id {
            break;
        }
+       if let Some(t) = tokenizer.next_token(next_token)? {
+           print!("{t}");
+           std::io::stdout().flush()?;
+       }
    }
+   if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+       print!("{rest}");
+   }
    let dt = start_gen.elapsed();
    println!(
```
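The `TokenOutputStream` wrapper replaces the per-token `id_to_token` heuristics: it buffers token ids and only emits text once the decoded suffix is complete, with `decode_rest` flushing whatever is still held back after the loop. A rough sketch of the underlying idea (illustrative, not candle's exact implementation; `decode` stands in for a real tokenizer's decode method):

```rust
/// Illustrative incremental detokenizer: decode the whole buffer each step
/// and emit only the new suffix, holding it back until it ends cleanly.
struct StreamSketch {
    tokens: Vec<u32>,
    printed: usize, // bytes already emitted
}

impl StreamSketch {
    fn next_token(&mut self, token: u32, decode: impl Fn(&[u32]) -> String) -> Option<String> {
        self.tokens.push(token);
        let text = decode(&self.tokens);
        // Hold output back until the decoded text grows and ends on a clean
        // boundary, so partial multi-byte glyphs are never printed.
        if text.len() > self.printed && text.chars().last().map_or(false, char::is_alphanumeric) {
            let new = text.get(self.printed..)?.to_string();
            self.printed = text.len();
            Some(new)
        } else {
            None
        }
    }

    /// Counterpart of decode_rest: flush whatever was held back.
    fn decode_rest(&self, decode: impl Fn(&[u32]) -> String) -> Option<String> {
        let text = decode(&self.tokens);
        text.get(self.printed..).filter(|s| !s.is_empty()).map(str::to_string)
    }
}
```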
candle-examples/examples/llama2-c/main.rs

```diff
@@ -328,6 +328,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
+   let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

    let start_gen = std::time::Instant::now();
    for index in 0.. {
@@ -353,16 +354,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);
-       // Extracting the last token as a string is complicated, here we just apply some simple
-       // heuristics as it seems to work well enough for this example. See the following for more
-       // details:
-       // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-       if let Some(text) = tokenizer.id_to_token(next_token) {
-           let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-           print!("{text}");
+       if let Some(t) = tokenizer.next_token(next_token)? {
+           print!("{t}");
            std::io::stdout().flush()?;
        }
    }
+   if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+       print!("{rest}");
+   }
    let dt = start_gen.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
```
Another example's `struct Args` gets the same default bump:

```diff
@@ -152,7 +152,7 @@ struct Args {
    seed: u64,

    /// The length of the sample to generate (in tokens).
-   #[arg(long, short = 'n', default_value_t = 100)]
+   #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
```
candle-examples/examples/mixtral/main.rs

```diff
@@ -143,7 +143,7 @@ struct Args {
    seed: u64,

    /// The length of the sample to generate (in tokens).
-   #[arg(long, short = 'n', default_value_t = 100)]
+   #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
```
candle-metal-kernels/.gitignore (vendored, new file)

```diff
@@ -0,0 +1,2 @@
+src/compiled/
```
candle-metal-kernels/build.rs (new file, 45 lines)

```rust
use std::path::Path;
use std::process::Command;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let files: std::fs::ReadDir = std::fs::read_dir("src/").unwrap();
    for file in files {
        let file = file?;
        let path = file.path();
        if let Some(extension) = path.extension() {
            if extension == "metal" {
                build_kernel(&path)?;
            }
            println!("cargo:warning=output {:?}", path.file_stem());
        }
    }
    Ok(())
}

fn build_kernel(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
    let stem = path
        .file_stem()
        .expect("expect real filename")
        .to_str()
        .expect("expect real stem");
    Command::new("xcrun")
        .args([
            "metal",
            "-c",
            path.as_os_str().to_str().expect("Expect a real filename"),
            "-I",
            "src/",
            "-o",
            &format!("src/compiled/{stem}.air"),
        ])
        .output()?;
    Command::new("xcrun")
        .args([
            "metallib",
            &format!("src/compiled/{stem}.air"),
            "-o",
            &format!("src/compiled/{stem}.metallib"),
        ])
        .output()?;
    Ok(())
}
```
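The script shells out to `xcrun metal` (MSL source to `.air`) and then `xcrun metallib` (`.air` to `.metallib`), writing into `src/compiled/`, which the new `.gitignore` excludes. One thing a build script like this usually also emits is `cargo:rerun-if-changed`, so that editing a kernel triggers recompilation; a sketch of that addition (not part of this commit):

```rust
// Sketch: emit rerun-if-changed for each Metal source so cargo reruns the
// build script (and thus rebuilds the .air/.metallib artifacts) whenever a
// kernel changes. Shown only to illustrate how build.rs hooks into cargo.
fn track_kernel_sources() -> std::io::Result<()> {
    for entry in std::fs::read_dir("src/")? {
        let path = entry?.path();
        if path.extension().map_or(false, |e| e == "metal") {
            println!("cargo:rerun-if-changed={}", path.display());
        }
    }
    Ok(())
}
```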
candle-metal-kernels/src/cast.metal

```diff
@@ -73,6 +73,7 @@ kernel void FN_NAME_STRIDED( \
} \

CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
@@ -95,4 +96,4 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
#endif
```
candle-metal-kernels/src/lib.rs

```diff
@@ -1,22 +1,22 @@
use metal::{
-   Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
-   Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
+   Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
+   FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

-const AFFINE: &str = include_str!("affine.metal");
-const INDEXING: &str = include_str!("indexing.metal");
-const UNARY: &str = include_str!("unary.metal");
-const BINARY: &str = include_str!("binary.metal");
-const TERNARY: &str = include_str!("ternary.metal");
-const CAST: &str = include_str!("cast.metal");
-const CONV: &str = include_str!("conv.metal");
-const REDUCE: &str = include_str!("reduce.metal");
-const RANDOM: &str = include_str!("random.metal");
+const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib");
+const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib");
+const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");
+const BINARY: &[u8] = include_bytes!("compiled/binary.metallib");
+const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib");
+const CAST: &[u8] = include_bytes!("compiled/cast.metallib");
+const CONV: &[u8] = include_bytes!("compiled/conv.metallib");
+const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib");
+const RANDOM: &[u8] = include_bytes!("compiled/random.metallib");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
-const QUANTIZED: &str = include_str!("quantized.metal");
+const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib");

/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@@ -235,7 +235,7 @@ impl Kernels {
        }
    }

-   fn get_library_source(&self, source: Source) -> &'static str {
+   fn get_library_source(&self, source: Source) -> &'static [u8] {
        match source {
            Source::Affine => AFFINE,
            Source::Unary => UNARY,
@@ -247,7 +247,7 @@ impl Kernels {
            Source::Conv => CONV,
            Source::Random => RANDOM,
            Source::Quantized => QUANTIZED,
-           Source::Mfa => panic!("Invalid lib"),
+           Source::Mfa => MFA,
        }
    }

@@ -262,22 +262,12 @@ impl Kernels {
        if let Some(lib) = libraries.get(&source) {
            Ok(lib.clone())
        } else {
-           let lib = match source {
-               Source::Mfa => {
-                   let source_data = MFA;
-                   device.new_library_with_data(source_data).map_err(|e| {
-                       MetalKernelError::LoadLibraryError(format!(
-                           "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
-                       ))
-                   })?
-               }
-               source => {
-                   let source_content = self.get_library_source(source);
-                   device
-                       .new_library_with_source(source_content, &CompileOptions::new())
-                       .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
-               }
-           };
+           let source_data = self.get_library_source(source);
+           let lib = device.new_library_with_data(source_data).map_err(|e| {
+               MetalKernelError::LoadLibraryError(format!(
+                   "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
+               ))
+           })?;
            libraries.insert(source, lib.clone());
            Ok(lib)
        }
```
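With the match collapsed, every library now goes through `new_library_with_data`, which deserializes a prebuilt metallib instead of compiling MSL source at runtime (`new_library_with_source` and `CompileOptions` drop out of the imports accordingly). A standalone sketch of that call using the `metal` crate (the metallib path here is an assumption):

```rust
use metal::Device;

fn main() -> Result<(), String> {
    // new_library_with_data parses a serialized metallib, so no MSL
    // compilation happens at runtime (unlike new_library_with_source).
    let device = Device::system_default().ok_or_else(|| "no Metal device".to_string())?;
    let bytes: &[u8] = include_bytes!("compiled/unary.metallib"); // assumed path
    let lib = device.new_library_with_data(bytes)?;
    println!("functions: {:?}", lib.function_names());
    Ok(())
}
```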
candle-nn/src/conv.rs

```diff
@@ -302,6 +302,22 @@ pub fn conv1d(
    Ok(Conv1d::new(ws, Some(bs), cfg))
}

+pub fn conv1d_no_bias(
+   in_channels: usize,
+   out_channels: usize,
+   kernel_size: usize,
+   cfg: Conv1dConfig,
+   vb: crate::VarBuilder,
+) -> Result<Conv1d> {
+   let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+   let ws = vb.get_with_hints(
+       (out_channels, in_channels / cfg.groups, kernel_size),
+       "weight",
+       init_ws,
+   )?;
+   Ok(Conv1d::new(ws, None, cfg))
+}
+
pub fn conv_transpose1d(
    in_channels: usize,
    out_channels: usize,
```
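A minimal usage sketch for the new helper (the shapes and the `VarMap`-backed builder are illustrative):

```rust
use candle::{DType, Device, Module, Tensor};
use candle_nn::{conv1d_no_bias, Conv1dConfig, VarBuilder, VarMap};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // 16 in-channels, 32 out-channels, kernel size 7, and no bias term.
    let conv = conv1d_no_bias(16, 32, 7, Conv1dConfig::default(), vb.pp("conv"))?;
    let xs = Tensor::zeros((1, 16, 100), DType::F32, &dev)?; // (batch, channels, length)
    let ys = conv.forward(&xs)?;
    println!("{:?}", ys.dims()); // [1, 32, 94] with the default stride 1, padding 0
    Ok(())
}
```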
candle-nn/src/lib.rs

```diff
@@ -19,8 +19,9 @@ pub mod var_map;
pub use activation::{prelu, Activation, PReLU};
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{
-   conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
-   Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
+   conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
+   conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
+   ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};
```
candle-transformers/src/models/mod.rs

```diff
@@ -42,7 +42,6 @@ pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
-pub mod vocos;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
```
```diff
@@ -165,9 +165,9 @@ impl SelfAttention {
        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        for t_ in 0..t {
            //
-           let rt = receptance.i((.., .., t_..t_ + 1))?;
-           let kt = key.i((.., .., .., t_..t_ + 1))?;
-           let vt = value.i((.., .., t_..t_ + 1))?;
+           let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
+           let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
+           let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
```
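The added `.contiguous()?` calls matter because `i(..)` narrows without copying, yielding a strided view, and the Metal matmul path expects contiguous operands. A small illustration of the pattern (run on CPU so it works anywhere; on Metal the copy is what avoids the backend error):

```rust
use candle::{Device, IndexOp, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let receptance = Tensor::randn(0f32, 1.0, (2, 4, 8), &dev)?;
    // i() narrows without copying, producing a non-contiguous view ...
    let rt = receptance.i((.., .., 3..4))?.contiguous()?; // (2, 4, 1)
    let rhs = Tensor::randn(0f32, 1.0, (2, 1, 5), &dev)?;
    // ... which matmul on Metal would reject without the contiguous() copy.
    let out = rt.matmul(&rhs)?; // (2, 4, 5)
    println!("{:?}", out.dims());
    Ok(())
}
```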
candle-transformers/src/models/vocos.rs (deleted, @@ -1,156 +0,0 @@)

```rust
#![allow(unused)]
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};

pub struct AdaLayerNorm {
    eps: f64,
    dim: usize,
    scale: Embedding,
    shift: Embedding,
}

fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let x = {
        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        x.broadcast_sub(&mean_x)?
    };
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)
}

impl AdaLayerNorm {
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        eps: f64,
        vb: VarBuilder,
    ) -> Result<Self> {
        let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
        let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
        Ok(Self {
            eps,
            dim: embedding_dim,
            scale,
            shift,
        })
    }

    pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
        let scale = self.scale.forward(cond_embedding_id)?;
        let shift = self.shift.forward(cond_embedding_id)?;
        let xs = layer_norm(xs, self.eps)?;
        xs.broadcast_mul(&scale)?.broadcast_add(&shift)
    }
}

pub struct ConvNeXtBlock {
    dwconv: Conv1d,
    pwconv1: Linear,
    pwconv2: Linear,
    gamma: Option<Tensor>,
}

impl ConvNeXtBlock {
    pub fn new(
        dim: usize,
        intermediate_dim: usize,
        layer_scale_init_value: f64,
        adanorm_num_embeddings: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let dwconv = {
            let cfg = Conv1dConfig {
                padding: 3,
                groups: dim,
                ..Default::default()
            };
            conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
        };
        let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
        let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
        let gamma = if layer_scale_init_value > 0. {
            Some(vb.get(dim, "gamma")?)
        } else {
            None
        };
        Ok(Self {
            dwconv,
            pwconv1,
            pwconv2,
            gamma,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
        // TODO: norm
        let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
        let xs = match self.gamma.as_ref() {
            Some(gamma) => xs.broadcast_mul(gamma)?,
            None => xs,
        };
        xs.transpose(1, 2)? + residual
    }
}

struct VocosBackbone {
    embed: Conv1d,
    convnext: Vec<ConvNeXtBlock>,
    final_layer_norm: candle_nn::LayerNorm,
}

impl VocosBackbone {
    pub fn new(
        input_channels: usize,
        dim: usize,
        intermediate_dim: usize,
        num_layers: usize,
        layer_scale_init_value: f64,
        adanorm_num_embeddings: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let embed = {
            let cfg = Conv1dConfig {
                padding: 3,
                ..Default::default()
            };
            conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
        };
        let mut convnext = Vec::with_capacity(num_layers);
        let vb_c = vb.pp("convnext");
        for i in 0..num_layers {
            let block = ConvNeXtBlock::new(
                dim,
                intermediate_dim,
                layer_scale_init_value,
                adanorm_num_embeddings,
                vb_c.pp(i),
            )?;
            convnext.push(block);
        }
        let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
        Ok(Self {
            embed,
            convnext,
            final_layer_norm,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.embed)?;
        // TODO: norm
        let mut xs = xs.transpose(1, 2)?;
        for conv_block in self.convnext.iter() {
            xs = conv_block.forward(&xs)?
        }
        xs.apply(&self.final_layer_norm)
    }
}
```