mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Compare commits
4 Commits
vocos
...
precompile
Author | SHA1 | Date | |
---|---|---|---|
5ac3302fac | |||
41416d2376 | |||
5ebcfeaf0f | |||
7c7400fb63 |
@ -588,6 +588,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||||
|
(DType::U32, DType::F16) => "cast_u32_f16",
|
||||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||||
|
|
||||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||||
|
@ -57,7 +57,7 @@ struct Args {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, default_value_t = 100)]
|
#[arg(long, default_value_t = 10000)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
/// Disable the key-value cache.
|
/// Disable the key-value cache.
|
||||||
@ -143,7 +143,6 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
||||||
};
|
};
|
||||||
println!("building the model");
|
|
||||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
@ -157,6 +156,7 @@ fn main() -> Result<()> {
|
|||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
|
||||||
|
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
print!("{prompt}");
|
print!("{prompt}");
|
||||||
@ -190,18 +190,16 @@ fn main() -> Result<()> {
|
|||||||
token_generated += 1;
|
token_generated += 1;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
|
|
||||||
// Extracting the last token as a string is complicated, here we just apply some simple
|
|
||||||
// heuristics as it seems to work well enough for this example. See the following for more
|
|
||||||
// details:
|
|
||||||
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
|
||||||
if let Some(text) = tokenizer.id_to_token(next_token) {
|
|
||||||
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
|
||||||
print!("{text}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
if Some(next_token) == eos_token_id {
|
if Some(next_token) == eos_token_id {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
}
|
}
|
||||||
let dt = start_gen.elapsed();
|
let dt = start_gen.elapsed();
|
||||||
println!(
|
println!(
|
||||||
|
@ -328,6 +328,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
|
||||||
|
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0.. {
|
for index in 0.. {
|
||||||
@ -353,16 +354,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
|
|
||||||
let next_token = logits_processor.sample(&logits)?;
|
let next_token = logits_processor.sample(&logits)?;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
// Extracting the last token as a string is complicated, here we just apply some simple
|
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||||
// heuristics as it seems to work well enough for this example. See the following for more
|
print!("{t}");
|
||||||
// details:
|
|
||||||
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
|
||||||
if let Some(text) = tokenizer.id_to_token(next_token) {
|
|
||||||
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
|
||||||
print!("{text}");
|
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
let dt = start_gen.elapsed();
|
let dt = start_gen.elapsed();
|
||||||
println!(
|
println!(
|
||||||
"\n{} tokens generated ({:.2} token/s)\n",
|
"\n{} tokens generated ({:.2} token/s)\n",
|
||||||
|
@ -152,7 +152,7 @@ struct Args {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, short = 'n', default_value_t = 100)]
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
|
@ -143,7 +143,7 @@ struct Args {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, short = 'n', default_value_t = 100)]
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
|
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
|
||||||
|
2
candle-metal-kernels/.gitignore
vendored
Normal file
2
candle-metal-kernels/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
src/compiled/
|
||||||
|
|
45
candle-metal-kernels/build.rs
Normal file
45
candle-metal-kernels/build.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use std::path::Path;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let files: std::fs::ReadDir = std::fs::read_dir("src/").unwrap();
|
||||||
|
for file in files {
|
||||||
|
let file = file?;
|
||||||
|
let path = file.path();
|
||||||
|
if let Some(extension) = path.extension() {
|
||||||
|
if extension == "metal" {
|
||||||
|
build_kernel(&path)?;
|
||||||
|
}
|
||||||
|
println!("cargo:warning=output {:?}", path.file_stem());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compiles one Metal source file into `src/compiled/<stem>.metallib`.
///
/// Two `xcrun` steps are run: `metal -c` produces `src/compiled/<stem>.air`, then `metallib`
/// turns that file into `src/compiled/<stem>.metallib`.
///
/// # Errors
///
/// Returns an error if the file stem is missing or not valid UTF-8, if the output directory
/// cannot be created, if `xcrun` cannot be spawned, or if either step exits unsuccessfully.
fn build_kernel(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
    let stem = path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .ok_or_else(|| format!("{}: file stem is missing or not valid UTF-8", path.display()))?;

    // The output directory is git-ignored, so it does not exist on a fresh checkout.
    let out_dir = Path::new("src/compiled");
    std::fs::create_dir_all(out_dir)?;
    let air = out_dir.join(format!("{stem}.air"));
    let metallib = out_dir.join(format!("{stem}.metallib"));

    let mut compile = Command::new("xcrun");
    compile
        .args(["metal", "-c"])
        .arg(path)
        .args(["-I", "src/", "-o"])
        .arg(&air);
    run_xcrun(&mut compile)?;

    let mut link = Command::new("xcrun");
    link.arg("metallib").arg(&air).arg("-o").arg(&metallib);
    run_xcrun(&mut link)
}

/// Runs `command` to completion and turns an unsuccessful exit status into an error that
/// carries the captured stderr.
fn run_xcrun(command: &mut Command) -> Result<(), Box<dyn std::error::Error>> {
    // `output()` only fails when the process cannot be spawned; the exit status has to be
    // checked separately or compiler errors go unnoticed.
    let output = command.output()?;
    if output.status.success() {
        return Ok(());
    }
    Err(format!(
        "{command:?} failed with {}:\n{}",
        output.status,
        String::from_utf8_lossy(&output.stderr)
    )
    .into())
}
|
@ -73,6 +73,7 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
} \
|
} \
|
||||||
|
|
||||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||||
|
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
|
||||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||||
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
|
||||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib");
|
||||||
const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib");
|
||||||
const UNARY: &str = include_str!("unary.metal");
|
const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");
|
||||||
const BINARY: &str = include_str!("binary.metal");
|
const BINARY: &[u8] = include_bytes!("compiled/binary.metallib");
|
||||||
const TERNARY: &str = include_str!("ternary.metal");
|
const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib");
|
||||||
const CAST: &str = include_str!("cast.metal");
|
const CAST: &[u8] = include_bytes!("compiled/cast.metallib");
|
||||||
const CONV: &str = include_str!("conv.metal");
|
const CONV: &[u8] = include_bytes!("compiled/conv.metallib");
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &[u8] = include_bytes!("compiled/random.metallib");
|
||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib");
|
||||||
|
|
||||||
/// Most kernels apply similarly across the tensors
|
/// Most kernels apply similarly across the tensors
|
||||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||||
@ -235,7 +235,7 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_library_source(&self, source: Source) -> &'static str {
|
fn get_library_source(&self, source: Source) -> &'static [u8] {
|
||||||
match source {
|
match source {
|
||||||
Source::Affine => AFFINE,
|
Source::Affine => AFFINE,
|
||||||
Source::Unary => UNARY,
|
Source::Unary => UNARY,
|
||||||
@ -247,7 +247,7 @@ impl Kernels {
|
|||||||
Source::Conv => CONV,
|
Source::Conv => CONV,
|
||||||
Source::Random => RANDOM,
|
Source::Random => RANDOM,
|
||||||
Source::Quantized => QUANTIZED,
|
Source::Quantized => QUANTIZED,
|
||||||
Source::Mfa => panic!("Invalid lib"),
|
Source::Mfa => MFA,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,22 +262,12 @@ impl Kernels {
|
|||||||
if let Some(lib) = libraries.get(&source) {
|
if let Some(lib) = libraries.get(&source) {
|
||||||
Ok(lib.clone())
|
Ok(lib.clone())
|
||||||
} else {
|
} else {
|
||||||
let lib = match source {
|
let source_data = self.get_library_source(source);
|
||||||
Source::Mfa => {
|
let lib = device.new_library_with_data(source_data).map_err(|e| {
|
||||||
let source_data = MFA;
|
|
||||||
device.new_library_with_data(source_data).map_err(|e| {
|
|
||||||
MetalKernelError::LoadLibraryError(format!(
|
MetalKernelError::LoadLibraryError(format!(
|
||||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||||
))
|
))
|
||||||
})?
|
})?;
|
||||||
}
|
|
||||||
source => {
|
|
||||||
let source_content = self.get_library_source(source);
|
|
||||||
device
|
|
||||||
.new_library_with_source(source_content, &CompileOptions::new())
|
|
||||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
libraries.insert(source, lib.clone());
|
libraries.insert(source, lib.clone());
|
||||||
Ok(lib)
|
Ok(lib)
|
||||||
}
|
}
|
||||||
|
@ -302,6 +302,22 @@ pub fn conv1d(
|
|||||||
Ok(Conv1d::new(ws, Some(bs), cfg))
|
Ok(Conv1d::new(ws, Some(bs), cfg))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn conv1d_no_bias(
|
||||||
|
in_channels: usize,
|
||||||
|
out_channels: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
cfg: Conv1dConfig,
|
||||||
|
vb: crate::VarBuilder,
|
||||||
|
) -> Result<Conv1d> {
|
||||||
|
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||||
|
let ws = vb.get_with_hints(
|
||||||
|
(out_channels, in_channels / cfg.groups, kernel_size),
|
||||||
|
"weight",
|
||||||
|
init_ws,
|
||||||
|
)?;
|
||||||
|
Ok(Conv1d::new(ws, None, cfg))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn conv_transpose1d(
|
pub fn conv_transpose1d(
|
||||||
in_channels: usize,
|
in_channels: usize,
|
||||||
out_channels: usize,
|
out_channels: usize,
|
||||||
|
@ -19,8 +19,9 @@ pub mod var_map;
|
|||||||
pub use activation::{prelu, Activation, PReLU};
|
pub use activation::{prelu, Activation, PReLU};
|
||||||
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
|
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
|
||||||
pub use conv::{
|
pub use conv::{
|
||||||
conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
|
conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
|
||||||
Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
|
conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
|
||||||
|
ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig,
|
||||||
};
|
};
|
||||||
pub use embedding::{embedding, Embedding};
|
pub use embedding::{embedding, Embedding};
|
||||||
pub use func::{func, func_t, Func, FuncT};
|
pub use func::{func, func_t, Func, FuncT};
|
||||||
|
@ -42,7 +42,6 @@ pub mod t5;
|
|||||||
pub mod trocr;
|
pub mod trocr;
|
||||||
pub mod vgg;
|
pub mod vgg;
|
||||||
pub mod vit;
|
pub mod vit;
|
||||||
pub mod vocos;
|
|
||||||
pub mod whisper;
|
pub mod whisper;
|
||||||
pub mod with_tracing;
|
pub mod with_tracing;
|
||||||
pub mod wuerstchen;
|
pub mod wuerstchen;
|
||||||
|
@ -165,9 +165,9 @@ impl SelfAttention {
|
|||||||
let mut out: Vec<Tensor> = Vec::with_capacity(t);
|
let mut out: Vec<Tensor> = Vec::with_capacity(t);
|
||||||
for t_ in 0..t {
|
for t_ in 0..t {
|
||||||
//
|
//
|
||||||
let rt = receptance.i((.., .., t_..t_ + 1))?;
|
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
|
||||||
let kt = key.i((.., .., .., t_..t_ + 1))?;
|
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
|
||||||
let vt = value.i((.., .., t_..t_ + 1))?;
|
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
|
||||||
let at = kt.matmul(&vt)?;
|
let at = kt.matmul(&vt)?;
|
||||||
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
|
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
|
||||||
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
|
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
|
||||||
|
@ -1,156 +0,0 @@
|
|||||||
#![allow(unused)]
|
|
||||||
use candle::{DType, Module, Result, Tensor, D};
|
|
||||||
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};
|
|
||||||
|
|
||||||
pub struct AdaLayerNorm {
|
|
||||||
eps: f64,
|
|
||||||
dim: usize,
|
|
||||||
scale: Embedding,
|
|
||||||
shift: Embedding,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
|
|
||||||
let x_dtype = x.dtype();
|
|
||||||
let internal_dtype = match x_dtype {
|
|
||||||
DType::F16 | DType::BF16 => DType::F32,
|
|
||||||
d => d,
|
|
||||||
};
|
|
||||||
let hidden_size = x.dim(D::Minus1)?;
|
|
||||||
let x = x.to_dtype(internal_dtype)?;
|
|
||||||
let x = {
|
|
||||||
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
|
||||||
x.broadcast_sub(&mean_x)?
|
|
||||||
};
|
|
||||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
|
||||||
let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
|
|
||||||
x_normed.to_dtype(x_dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AdaLayerNorm {
|
|
||||||
pub fn new(
|
|
||||||
num_embeddings: usize,
|
|
||||||
embedding_dim: usize,
|
|
||||||
eps: f64,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
|
|
||||||
let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
|
|
||||||
Ok(Self {
|
|
||||||
eps,
|
|
||||||
dim: embedding_dim,
|
|
||||||
scale,
|
|
||||||
shift,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
|
|
||||||
let scale = self.scale.forward(cond_embedding_id)?;
|
|
||||||
let shift = self.shift.forward(cond_embedding_id)?;
|
|
||||||
let xs = layer_norm(xs, self.eps)?;
|
|
||||||
xs * scale + shift
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ConvNeXtBlock {
|
|
||||||
dwconv: Conv1d,
|
|
||||||
pwconv1: Linear,
|
|
||||||
pwconv2: Linear,
|
|
||||||
gamma: Option<Tensor>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ConvNeXtBlock {
|
|
||||||
pub fn new(
|
|
||||||
dim: usize,
|
|
||||||
intermediate_dim: usize,
|
|
||||||
layer_scale_init_value: f64,
|
|
||||||
adanorm_num_embeddings: Option<usize>,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let dwconv = {
|
|
||||||
let cfg = Conv1dConfig {
|
|
||||||
padding: 3,
|
|
||||||
groups: dim,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
|
|
||||||
};
|
|
||||||
let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
|
|
||||||
let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
|
|
||||||
let gamma = if layer_scale_init_value > 0. {
|
|
||||||
Some(vb.get(dim, "gamma")?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
Ok(Self {
|
|
||||||
dwconv,
|
|
||||||
pwconv1,
|
|
||||||
pwconv2,
|
|
||||||
gamma,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let residual = xs;
|
|
||||||
let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
|
|
||||||
// TODO: norm
|
|
||||||
let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
|
|
||||||
let xs = match self.gamma.as_ref() {
|
|
||||||
Some(gamma) => (gamma * xs)?,
|
|
||||||
None => xs,
|
|
||||||
};
|
|
||||||
xs.transpose(1, 2)? + residual
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct VocosBackbone {
|
|
||||||
embed: Conv1d,
|
|
||||||
convnext: Vec<ConvNeXtBlock>,
|
|
||||||
final_layer_norm: candle_nn::LayerNorm,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl VocosBackbone {
|
|
||||||
pub fn new(
|
|
||||||
input_channels: usize,
|
|
||||||
dim: usize,
|
|
||||||
intermediate_dim: usize,
|
|
||||||
num_layers: dim,
|
|
||||||
layer_scale_init_value: f64,
|
|
||||||
adanorm_num_embeddings: Option<usize>,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let embed = {
|
|
||||||
let cfg = Conv1dConfig {
|
|
||||||
padding: 3,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
|
|
||||||
};
|
|
||||||
let mut convnext = Vec::with_capacity(num_layers);
|
|
||||||
let vb_c = vb.pp("convnext");
|
|
||||||
for i in 0..num_layers {
|
|
||||||
let block = ConvNeXtBlock::new(
|
|
||||||
dim,
|
|
||||||
intermediate_dim,
|
|
||||||
layer_scale_init_value,
|
|
||||||
adanorm_num_embeddings,
|
|
||||||
vb_c.pp(i),
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
|
|
||||||
Ok(Self {
|
|
||||||
embed,
|
|
||||||
convnext,
|
|
||||||
final_layer_norm,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let xs = xs.apply(&self.embed)?;
|
|
||||||
// TODO: norm
|
|
||||||
let mut xs = xs.transpose(1, 2)?;
|
|
||||||
for conv_block in self.convnext.iter() {
|
|
||||||
xs = conv_block.forward(&xs)?
|
|
||||||
}
|
|
||||||
xs.apply(&self.final_layer_norm)
|
|
||||||
}
|
|
||||||
}
|
|
Reference in New Issue
Block a user