Mirror of https://github.com/huggingface/candle.git (synced 2025-06-14 09:57:10 +00:00)
Add the mimi audio-tokenizer. (#2488)
* Add the mimi audio-tokenizer.
* Formatting tweaks.
* Add a full example.
* Use the transformers names.
* More renamings.
* Get encoding and decoding to work.
* Clippy fixes.
.gitignore (vendored) | 3
@@ -43,3 +43,6 @@ candle-wasm-examples/**/config*.json
__pycache__
out.safetensors
out.wav
bria.mp3
bria.safetensors
bria.wav
@@ -67,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]

[[example]]
@@ -101,6 +102,10 @@ required-features = ["candle-datasets"]
name = "llama2-c"
required-features = ["candle-datasets"]

[[example]]
name = "mimi"
required-features = ["mimi"]

[[example]]
name = "encodec"
required-features = ["encodec"]
candle-examples/examples/mimi/README.md (new file) | 20
@@ -0,0 +1,20 @@
# candle-mimi

[Mimi](https://huggingface.co/kyutai/mimi) is a state-of-the-art audio
compression model using an encoder/decoder architecture with residual vector
quantization. The candle implementation supports streaming, meaning that it is
possible to encode or decode a stream of audio tokens on the fly to provide
low-latency interaction with an audio model.

## Running one example

Generating audio tokens from an audio file:

```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
```

And decoding the audio tokens back into a sound file:

```bash
cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
```
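
The input file can be replaced by `-` to record from the default microphone, and
the output file by `-` to play the decoded audio directly on the default output
device; the `audio-to-audio` action runs the full encode/decode round-trip.

For low-latency use cases the model also exposes a streaming API through
`encode_step` and `decode_step`, which operate on `StreamTensor`s. The snippet
below is a rough, untested sketch of a chunked round-trip (it is not part of the
example binary); the 1920-sample chunk corresponds to one frame at 24kHz / 12.5Hz.

```rust
use anyhow::Result;
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mimi::{Config, Model};

/// Streams 24kHz mono pcm data through the mimi codec, one frame at a time.
fn streaming_roundtrip(pcm: &[f32], model_file: &str) -> Result<Vec<f32>> {
    let device = Device::Cpu;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let mut model = Model::new(Config::v0_1(None), vb)?;
    model.reset_state();
    let mut out = vec![];
    for chunk in pcm.chunks(1920) {
        let chunk = Tensor::from_slice(chunk, (1, 1, chunk.len()), &device)?;
        // The streaming convolutions buffer partial frames internally, so the encoder
        // may not emit codes (and the decoder may not emit pcm) on every step.
        let codes = model.encode_step(&chunk.into())?;
        let pcm_out = model.decode_step(&codes)?;
        if let Some(pcm_out) = pcm_out.as_option() {
            out.extend(pcm_out.i(0)?.i(0)?.to_vec1::<f32>()?);
        }
    }
    Ok(out)
}
```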
candle-examples/examples/mimi/audio_io.rs (new file) | 275
@@ -0,0 +1,275 @@
|
||||
#![allow(unused)]
|
||||
use anyhow::{Context, Result};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub const SAMPLE_RATE: usize = 24_000;
|
||||
|
||||
pub(crate) struct AudioOutputData_ {
|
||||
resampled_data: std::collections::VecDeque<f32>,
|
||||
resampler: rubato::FastFixedIn<f32>,
|
||||
output_buffer: Vec<f32>,
|
||||
input_buffer: Vec<f32>,
|
||||
input_len: usize,
|
||||
}
|
||||
|
||||
impl AudioOutputData_ {
|
||||
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
|
||||
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
|
||||
let resampler = rubato::FastFixedIn::new(
|
||||
resample_ratio,
|
||||
f64::max(resample_ratio, 1.0),
|
||||
rubato::PolynomialDegree::Septic,
|
||||
1024,
|
||||
1,
|
||||
)?;
|
||||
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
|
||||
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
|
||||
Ok(Self {
|
||||
resampled_data,
|
||||
resampler,
|
||||
input_buffer,
|
||||
output_buffer,
|
||||
input_len: 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
use rubato::Resampler;
|
||||
self.output_buffer.fill(0.);
|
||||
self.input_buffer.fill(0.);
|
||||
self.resampler.reset();
|
||||
self.resampled_data.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn take_all(&mut self) -> Vec<f32> {
|
||||
let mut data = Vec::with_capacity(self.resampled_data.len());
|
||||
while let Some(elem) = self.resampled_data.pop_back() {
|
||||
data.push(elem);
|
||||
}
|
||||
data
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.resampled_data.is_empty()
|
||||
}
|
||||
|
||||
// Assumes that the input buffer is large enough.
|
||||
fn push_input_buffer(&mut self, samples: &[f32]) {
|
||||
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
|
||||
self.input_len += samples.len()
|
||||
}
|
||||
|
||||
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pos_in = 0;
|
||||
loop {
|
||||
let rem = self.input_buffer.len() - self.input_len;
|
||||
let pos_end = usize::min(pos_in + rem, samples.len());
|
||||
self.push_input_buffer(&samples[pos_in..pos_end]);
|
||||
pos_in = pos_end;
|
||||
if self.input_len < self.input_buffer.len() {
|
||||
break;
|
||||
}
|
||||
let (_, out_len) = self.resampler.process_into_buffer(
|
||||
&[&self.input_buffer],
|
||||
&mut [&mut self.output_buffer],
|
||||
None,
|
||||
)?;
|
||||
for &elem in self.output_buffer[..out_len].iter() {
|
||||
self.resampled_data.push_front(elem)
|
||||
}
|
||||
self.input_len = 0;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
|
||||
|
||||
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
println!("Setup audio output stream!");
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_output_device()
|
||||
.context("no output device available")?;
|
||||
let mut supported_configs_range = device.supported_output_configs()?;
|
||||
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
|
||||
// On macOS, it's commonly the case that there are only stereo outputs.
|
||||
None => device
|
||||
.supported_output_configs()?
|
||||
.next()
|
||||
.context("no audio output available")?,
|
||||
Some(config_range) => config_range,
|
||||
};
|
||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||
config_range.min_sample_rate(),
|
||||
config_range.max_sample_rate(),
|
||||
);
|
||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||
let channels = config.channels as usize;
|
||||
println!(
|
||||
"cpal device: {} {} {config:?}",
|
||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||
config.sample_rate.0
|
||||
);
|
||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||
SAMPLE_RATE,
|
||||
config.sample_rate.0 as usize,
|
||||
)?));
|
||||
let ad = audio_data.clone();
|
||||
let stream = device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||
data.fill(0.);
|
||||
let mut ad = ad.lock().unwrap();
|
||||
let mut last_elem = 0f32;
|
||||
for (idx, elem) in data.iter_mut().enumerate() {
|
||||
if idx % channels == 0 {
|
||||
match ad.resampled_data.pop_back() {
|
||||
None => break,
|
||||
Some(v) => {
|
||||
last_elem = v;
|
||||
*elem = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*elem = last_elem
|
||||
}
|
||||
}
|
||||
},
|
||||
move |err| eprintln!("cpal error: {err}"),
|
||||
None, // None=blocking, Some(Duration)=timeout
|
||||
)?;
|
||||
stream.play()?;
|
||||
Ok((stream, audio_data))
|
||||
}
|
||||
|
||||
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
println!("Setup audio input stream!");
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_input_device()
|
||||
.context("no input device available")?;
|
||||
let mut supported_configs_range = device.supported_input_configs()?;
|
||||
let config_range = supported_configs_range
|
||||
.find(|c| c.channels() == 1)
|
||||
.context("no audio input available")?;
|
||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||
config_range.min_sample_rate(),
|
||||
config_range.max_sample_rate(),
|
||||
);
|
||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||
println!(
|
||||
"cpal device: {} {} {config:?}",
|
||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||
config.sample_rate.0
|
||||
);
|
||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||
config.sample_rate.0 as usize,
|
||||
SAMPLE_RATE,
|
||||
)?));
|
||||
let ad = audio_data.clone();
|
||||
let stream = device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let mut ad = ad.lock().unwrap();
|
||||
if let Err(err) = ad.push_samples(data) {
|
||||
eprintln!("error processing audio input {err:?}")
|
||||
}
|
||||
},
|
||||
move |err| eprintln!("cpal error: {err}"),
|
||||
None, // None=blocking, Some(Duration)=timeout
|
||||
)?;
|
||||
stream.play()?;
|
||||
Ok((stream, audio_data))
|
||||
}
|
||||
|
||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||
where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
use symphonia::core::audio::Signal;
|
||||
use symphonia::core::conv::FromSample;
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
|
||||
let src = std::fs::File::open(path)?;
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||
let mut format = probed.format;
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||
.expect("no supported audio tracks");
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &Default::default())
|
||||
.expect("unsupported codec");
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
||||
|
||||
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pcm_out =
|
||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||
|
||||
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
|
||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||
let mut pos_in = 0;
|
||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||
let (in_len, out_len) =
|
||||
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
|
||||
pos_in += in_len;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
if pos_in < pcm_in.len() {
|
||||
let (_in_len, out_len) = resampler.process_partial_into_buffer(
|
||||
Some(&[&pcm_in[pos_in..]]),
|
||||
&mut output_buffer,
|
||||
None,
|
||||
)?;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
Ok(pcm_out)
|
||||
}
|
candle-examples/examples/mimi/main.rs (new file) | 131
@@ -0,0 +1,131 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::mimi::{Config, Model};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::api::sync::Api;
|
||||
|
||||
mod audio_io;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Action {
|
||||
AudioToAudio,
|
||||
AudioToCode,
|
||||
CodeToAudio,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The action to be performed, specifies the format for the input and output data.
|
||||
action: Action,
|
||||
|
||||
/// The input file, either an audio file or some mimi tokens stored as safetensors.
|
||||
in_file: String,
|
||||
|
||||
/// The output file, either a wave audio file or some mimi tokens stored as safetensors.
|
||||
out_file: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
.model("kyutai/mimi".to_string())
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let config = Config::v0_1(None);
|
||||
let mut model = Model::new(config, vb)?;
|
||||
|
||||
let codes = match args.action {
|
||||
Action::CodeToAudio => {
|
||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||
codes.get("codes").expect("no codes in input file").clone()
|
||||
}
|
||||
Action::AudioToCode | Action::AudioToAudio => {
|
||||
let pcm = if args.in_file == "-" {
|
||||
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
|
||||
let (stream, input_audio) = audio_io::setup_input_stream()?;
|
||||
let mut pcms = vec![];
|
||||
let stdin = std::thread::spawn(|| {
|
||||
let mut s = String::new();
|
||||
std::io::stdin().read_line(&mut s)
|
||||
});
|
||||
while !stdin.is_finished() {
|
||||
let input = input_audio.lock().unwrap().take_all();
|
||||
if input.is_empty() {
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
continue;
|
||||
}
|
||||
pcms.push(input)
|
||||
}
|
||||
drop(stream);
|
||||
pcms.concat()
|
||||
} else {
|
||||
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
||||
if sample_rate != 24_000 {
|
||||
println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
|
||||
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
|
||||
} else {
|
||||
pcm
|
||||
}
|
||||
};
|
||||
let pcm_len = pcm.len();
|
||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||
println!("input pcm shape: {:?}", pcm.shape());
|
||||
model.encode(&pcm)?
|
||||
}
|
||||
};
|
||||
println!("codes shape: {:?}", codes.shape());
|
||||
|
||||
match args.action {
|
||||
Action::AudioToCode => {
|
||||
codes.save_safetensors("codes", &args.out_file)?;
|
||||
}
|
||||
Action::AudioToAudio | Action::CodeToAudio => {
|
||||
let pcm = model.decode(&codes)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
if args.out_file == "-" {
|
||||
let (stream, ad) = audio_io::setup_output_stream()?;
|
||||
{
|
||||
let mut ad = ad.lock().unwrap();
|
||||
ad.push_samples(&pcm)?;
|
||||
}
|
||||
loop {
|
||||
let ad = ad.lock().unwrap();
|
||||
if ad.is_empty() {
|
||||
break;
|
||||
}
|
||||
// That's very weird, calling thread::sleep here triggers the stream to stop
|
||||
// playing (the callback doesn't seem to be called anymore).
|
||||
// std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
}
|
||||
drop(stream)
|
||||
} else {
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
candle-transformers/src/models/mimi/conv.rs (new file) | 670
@@ -0,0 +1,670 @@
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D};
|
||||
use candle_nn::{Conv1d, VarBuilder};
|
||||
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum Norm {
|
||||
WeightNorm,
|
||||
SpectralNorm,
|
||||
TimeGroupNorm,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum PadMode {
|
||||
Constant,
|
||||
Reflect,
|
||||
Replicate,
|
||||
}
|
||||
|
||||
// Applies weight norm for inference by recomputing the weight tensor. This
|
||||
// does not apply to training.
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
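// The reconstructed weight is w = weight_g * weight_v / ||weight_v||, with the
// norm of weight_v taken over its (in_c, kernel_size) dimensions.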
|
||||
fn conv1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
bias: bool,
|
||||
config: candle_nn::Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight = if vb.contains_tensor("weight") {
|
||||
vb.get((out_c, in_c, kernel_size), "weight")?
|
||||
} else {
|
||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
|
||||
};
|
||||
let bias = if bias {
|
||||
Some(vb.get(out_c, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Conv1d::new(weight, bias, config))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NormConv1d {
|
||||
conv: Conv1d,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl NormConv1d {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k_size: usize,
|
||||
causal: bool,
|
||||
norm: Option<Norm>,
|
||||
bias: bool,
|
||||
cfg: candle_nn::Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let conv = match norm {
|
||||
None | Some(Norm::TimeGroupNorm) => {
|
||||
if bias {
|
||||
candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
|
||||
} else {
|
||||
candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
|
||||
}
|
||||
}
|
||||
Some(Norm::WeightNorm) => {
|
||||
conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
|
||||
}
|
||||
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
|
||||
};
|
||||
let norm = match norm {
|
||||
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
|
||||
Some(Norm::TimeGroupNorm) => {
|
||||
if causal {
|
||||
candle::bail!("GroupNorm doesn't support causal evaluation.")
|
||||
}
|
||||
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(norm)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
conv,
|
||||
norm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for NormConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = xs.apply(&self.conv)?;
|
||||
match self.norm.as_ref() {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NormConvTranspose1d {
|
||||
ws: Tensor,
|
||||
bs: Option<Tensor>,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl NormConvTranspose1d {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k_size: usize,
|
||||
causal: bool,
|
||||
norm: Option<Norm>,
|
||||
bias: bool,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let vb = vb.pp("conv");
|
||||
let bs = if bias {
|
||||
Some(vb.get(out_c, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let ws = match norm {
|
||||
None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
|
||||
Some(Norm::WeightNorm) => {
|
||||
if vb.contains_tensor("weight") {
|
||||
vb.get((in_c, out_c, k_size), "weight")?
|
||||
} else {
|
||||
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
|
||||
}
|
||||
}
|
||||
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
|
||||
};
|
||||
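// A depthwise transposed conv (groups == in_c == out_c) is expanded here into an
// equivalent dense kernel that is zero across channels, so it can be run with groups = 1.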
let (ws, groups) = if groups == out_c && in_c == out_c {
|
||||
let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
|
||||
let ws = ws
|
||||
.repeat((1, out_c, 1))?
|
||||
.mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
|
||||
(ws, 1)
|
||||
} else {
|
||||
(ws, groups)
|
||||
};
|
||||
let norm = match norm {
|
||||
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
|
||||
Some(Norm::TimeGroupNorm) => {
|
||||
if causal {
|
||||
candle::bail!("GroupNorm doesn't support causal evaluation.")
|
||||
}
|
||||
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(norm)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
ws,
|
||||
bs,
|
||||
k_size,
|
||||
stride,
|
||||
groups,
|
||||
norm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for NormConvTranspose1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
// conv-transpose1d used to be broken on metal after enough iterations, causing
// the following error:
|
||||
// _status < MTLCommandBufferStatusCommitted >
|
||||
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
|
||||
// This is now fixed in candle.
|
||||
let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
|
||||
let xs = match &self.bs {
|
||||
None => xs,
|
||||
Some(bias) => {
|
||||
let b = bias.dims1()?;
|
||||
let bias = bias.reshape((1, b, 1))?;
|
||||
xs.broadcast_add(&bias)?
|
||||
}
|
||||
};
|
||||
match self.norm.as_ref() {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
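// Amount of extra right-padding needed so that the last convolution window fully
// covers the end of the (already padded) input.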
fn get_extra_padding_for_conv1d(
|
||||
xs: &Tensor,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
padding_total: usize,
|
||||
) -> Result<usize> {
|
||||
let len = xs.dim(D::Minus1)?;
|
||||
let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
|
||||
let ideal_len =
|
||||
((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
|
||||
Ok(ideal_len.saturating_sub(len))
|
||||
}
|
||||
|
||||
fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
|
||||
match mode {
|
||||
PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
|
||||
PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
|
||||
PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
|
||||
}
|
||||
}
|
||||
|
||||
fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
|
||||
let len = xs.dim(D::Minus1)?;
|
||||
if len < unpad_l + unpad_r {
|
||||
candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
|
||||
}
|
||||
xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamableConv1d {
|
||||
conv: NormConv1d,
|
||||
causal: bool,
|
||||
pad_mode: PadMode,
|
||||
state_prev_xs: StreamTensor,
|
||||
left_pad_applied: bool,
|
||||
kernel_size: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl StreamableConv1d {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
bias: bool,
|
||||
causal: bool,
|
||||
norm: Option<Norm>,
|
||||
pad_mode: PadMode,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let cfg = candle_nn::Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
};
|
||||
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
|
||||
if k_size < stride {
|
||||
candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
|
||||
}
|
||||
Ok(Self {
|
||||
conv,
|
||||
causal,
|
||||
pad_mode,
|
||||
state_prev_xs: StreamTensor::empty(),
|
||||
left_pad_applied: false,
|
||||
kernel_size: k_size,
|
||||
span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for StreamableConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_b, _t, _c) = xs.dims3()?;
|
||||
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
|
||||
let conv_cfg = self.conv.conv.config();
|
||||
// Effective kernel size with dilations.
|
||||
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
|
||||
let padding_total = k_size - conv_cfg.stride;
|
||||
let extra_padding =
|
||||
get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
|
||||
let xs = if self.causal {
|
||||
pad1d(xs, padding_total, extra_padding, self.pad_mode)?
|
||||
} else {
|
||||
let padding_right = padding_total / 2;
|
||||
let padding_left = padding_total - padding_right;
|
||||
pad1d(
|
||||
xs,
|
||||
padding_left,
|
||||
padding_right + extra_padding,
|
||||
self.pad_mode,
|
||||
)?
|
||||
};
|
||||
xs.apply(&self.conv)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for StreamableConv1d {
|
||||
fn reset_state(&mut self) {
|
||||
self.state_prev_xs.reset();
|
||||
self.left_pad_applied = false;
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = match xs.as_option() {
|
||||
None => return Ok(().into()),
|
||||
Some(xs) => xs.clone(),
|
||||
};
|
||||
let xs = if self.left_pad_applied {
|
||||
xs
|
||||
} else {
|
||||
self.left_pad_applied = true;
|
||||
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
|
||||
let conv_cfg = self.conv.conv.config();
|
||||
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
|
||||
let padding_total = k_size - conv_cfg.stride;
|
||||
pad1d(&xs, padding_total, 0, self.pad_mode)?
|
||||
};
|
||||
let cfg = self.conv.conv.config();
|
||||
let stride = cfg.stride;
|
||||
let dilation = cfg.dilation;
|
||||
let kernel = (self.kernel_size - 1) * dilation + 1;
|
||||
let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
|
||||
let seq_len = xs.seq_len(D::Minus1)?;
|
||||
let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
|
||||
if num_frames > 0 {
|
||||
let offset = num_frames * stride;
|
||||
self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
|
||||
let in_l = (num_frames - 1) * stride + kernel;
|
||||
let xs = xs.narrow(D::Minus1, 0, in_l)?;
|
||||
// We apply the underlying conv directly rather than through forward so as
|
||||
// not to apply any padding here.
|
||||
xs.apply(&self.conv.conv)
|
||||
} else {
|
||||
self.state_prev_xs = xs;
|
||||
Ok(StreamTensor::empty())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamableConvTranspose1d {
|
||||
convtr: NormConvTranspose1d,
|
||||
causal: bool,
|
||||
state_prev_ys: StreamTensor,
|
||||
kernel_size: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl StreamableConvTranspose1d {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
bias: bool,
|
||||
causal: bool,
|
||||
norm: Option<Norm>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let convtr =
|
||||
NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?;
|
||||
Ok(Self {
|
||||
convtr,
|
||||
causal,
|
||||
kernel_size: k_size,
|
||||
state_prev_ys: StreamTensor::empty(),
|
||||
span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for StreamableConvTranspose1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let k_size = self.convtr.k_size;
|
||||
let stride = self.convtr.stride;
|
||||
let padding_total = k_size.saturating_sub(stride);
|
||||
let xs = xs.apply(&self.convtr)?;
|
||||
if self.causal {
|
||||
// This corresponds to trim_right_ratio = 1.
|
||||
unpad1d(&xs, 0, padding_total)
|
||||
} else {
|
||||
let padding_right = padding_total / 2;
|
||||
let padding_left = padding_total - padding_right;
|
||||
unpad1d(&xs, padding_left, padding_right)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for StreamableConvTranspose1d {
|
||||
fn reset_state(&mut self) {
|
||||
self.state_prev_ys.reset()
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = match xs.as_option() {
|
||||
Some(xs) => xs,
|
||||
None => return Ok(StreamTensor::empty()),
|
||||
};
|
||||
let stride = self.convtr.stride;
|
||||
// We apply the underlying convtr directly rather than through forward so as
|
||||
// not to apply any padding here.
|
||||
let ys = self.convtr.forward(xs)?;
|
||||
let ot = ys.dim(D::Minus1)?;
|
||||
let ys = match self.state_prev_ys.as_option() {
|
||||
None => ys,
|
||||
Some(prev_ys) => {
|
||||
let pt = prev_ys.dim(D::Minus1)?;
|
||||
// Remove the bias as it will be applied multiple times.
|
||||
let prev_ys = match &self.convtr.bs {
|
||||
None => prev_ys.clone(),
|
||||
Some(bias) => {
|
||||
let bias = bias.reshape((1, (), 1))?;
|
||||
prev_ys.broadcast_sub(&bias)?
|
||||
}
|
||||
};
|
||||
let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
|
||||
let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
|
||||
Tensor::cat(&[ys1, ys2], D::Minus1)?
|
||||
}
|
||||
};
|
||||
let invalid_steps = self.kernel_size - stride;
|
||||
let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
|
||||
self.state_prev_ys = prev_ys;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConvDownsample1d {
|
||||
conv: StreamableConv1d,
|
||||
}
|
||||
|
||||
impl ConvDownsample1d {
|
||||
pub fn new(
|
||||
stride: usize,
|
||||
dim: usize,
|
||||
causal: bool,
|
||||
learnt: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
if !learnt {
|
||||
candle::bail!("only learnt=true is supported")
|
||||
}
|
||||
let conv = StreamableConv1d::new(
|
||||
/* in_c */ dim,
|
||||
/* out_c */ dim,
|
||||
/* k_size_c */ 2 * stride,
|
||||
/* stride */ stride,
|
||||
/* dilation */ 1,
|
||||
/* groups */ 1, // channel_wise = false
|
||||
/* bias */ false,
|
||||
/* causal */ causal,
|
||||
/* norm */ None,
|
||||
/* pad_mode */ PadMode::Replicate,
|
||||
vb,
|
||||
)?;
|
||||
Ok(Self { conv })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ConvDownsample1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.conv)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for ConvDownsample1d {
|
||||
fn reset_state(&mut self) {
|
||||
self.conv.reset_state()
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
self.conv.step(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConvTrUpsample1d {
|
||||
convtr: StreamableConvTranspose1d,
|
||||
}
|
||||
|
||||
impl ConvTrUpsample1d {
|
||||
pub fn new(
|
||||
stride: usize,
|
||||
dim: usize,
|
||||
causal: bool,
|
||||
learnt: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
if !learnt {
|
||||
candle::bail!("only learnt=true is supported")
|
||||
}
|
||||
let convtr = StreamableConvTranspose1d::new(
|
||||
dim,
|
||||
dim,
|
||||
/* k_size */ 2 * stride,
|
||||
/* stride */ stride,
|
||||
/* groups */ dim,
|
||||
/* bias */ false,
|
||||
/* causal */ causal,
|
||||
/* norm */ None,
|
||||
vb,
|
||||
)?;
|
||||
Ok(Self { convtr })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ConvTrUpsample1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.convtr)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for ConvTrUpsample1d {
|
||||
fn reset_state(&mut self) {
|
||||
self.convtr.reset_state()
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
self.convtr.step(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle::IndexOp;
|
||||
|
||||
fn run_conv1d(
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
step_size: usize,
|
||||
len: usize,
|
||||
bias: bool,
|
||||
) -> Result<()> {
|
||||
// TODO: We should ensure for the seed to be constant when running these tests.
|
||||
let dev = &candle::Device::Cpu;
|
||||
let vm = candle_nn::VarMap::new();
|
||||
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
|
||||
let conv1d = StreamableConv1d::new(
|
||||
/* in_c */ 2,
|
||||
/* out_c */ 3,
|
||||
/* k_size */ k_size,
|
||||
/* stride */ stride,
|
||||
/* dilation */ dilation,
|
||||
/* groups */ 1,
|
||||
/* bias */ bias,
|
||||
/* causal */ true,
|
||||
/* norm */ None,
|
||||
/* pad_mode */ PadMode::Constant,
|
||||
vb,
|
||||
)?;
|
||||
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
|
||||
let ys = conv1d.forward(&xs)?;
|
||||
let mut conv1d = conv1d;
|
||||
let mut ys_steps = vec![];
|
||||
for idx in 0..len {
|
||||
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
|
||||
let ys = conv1d.step(&xs.into())?;
|
||||
if let Some(ys) = ys.as_option() {
|
||||
ys_steps.push(ys.clone())
|
||||
}
|
||||
}
|
||||
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
|
||||
let diff = (&ys - &ys_steps)?
|
||||
.abs()?
|
||||
.flatten_all()?
|
||||
.max(0)?
|
||||
.to_vec0::<f32>()?;
|
||||
if diff > 1e-5 {
|
||||
println!("{xs}");
|
||||
println!("{ys}");
|
||||
println!("{ys_steps}");
|
||||
candle::bail!("larger diff than expected {diff}")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_conv_tr1d(
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
step_size: usize,
|
||||
len: usize,
|
||||
bias: bool,
|
||||
) -> Result<()> {
|
||||
// TODO: We should ensure for the seed to be constant when running these tests.
|
||||
let dev = &candle::Device::Cpu;
|
||||
let vm = candle_nn::VarMap::new();
|
||||
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
|
||||
let conv1d = StreamableConvTranspose1d::new(
|
||||
/* in_c */ 2, /* out_c */ 3, /* k_size */ k_size,
|
||||
/* stride */ stride, /* groups */ 1, /* bias */ bias,
|
||||
/* causal */ true, /* norm */ None, vb,
|
||||
)?;
|
||||
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
|
||||
let ys = conv1d.forward(&xs)?;
|
||||
let mut conv1d = conv1d;
|
||||
let mut ys_steps = vec![];
|
||||
for idx in 0..len {
|
||||
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
|
||||
let ys = conv1d.step(&xs.into())?;
|
||||
if let Some(ys) = ys.as_option() {
|
||||
ys_steps.push(ys.clone())
|
||||
}
|
||||
}
|
||||
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
|
||||
let diff = (&ys - &ys_steps)?
|
||||
.abs()?
|
||||
.flatten_all()?
|
||||
.max(0)?
|
||||
.to_vec0::<f32>()?;
|
||||
if diff > 1e-5 {
|
||||
println!("{xs}");
|
||||
println!("{ys}");
|
||||
println!("{ys_steps}");
|
||||
candle::bail!("larger diff than expected {diff}")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv1d() -> Result<()> {
|
||||
for step_size in [1, 2, 3] {
|
||||
for bias in [false, true] {
|
||||
run_conv1d(1, 1, 1, step_size, 5, bias)?;
|
||||
run_conv1d(2, 1, 1, step_size, 5, bias)?;
|
||||
run_conv1d(2, 2, 1, step_size, 6, bias)?;
|
||||
run_conv1d(3, 2, 1, step_size, 8, bias)?;
|
||||
run_conv1d(3, 2, 2, step_size, 8, bias)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_tr1d() -> Result<()> {
|
||||
for step_size in [1, 2, 3] {
|
||||
for bias in [false, true] {
|
||||
run_conv_tr1d(1, 1, step_size, 5, bias)?;
|
||||
run_conv_tr1d(2, 1, step_size, 5, bias)?;
|
||||
run_conv_tr1d(3, 1, step_size, 5, bias)?;
|
||||
run_conv_tr1d(3, 2, step_size, 5, bias)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
candle-transformers/src/models/mimi/encodec.rs (new file) | 229
@@ -0,0 +1,229 @@
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
use super::{conv, quantization, seanet, transformer};
|
||||
use candle::{DType, Device, Module, Result, StreamTensor, StreamingModule, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum ResampleMethod {
|
||||
Conv,
|
||||
Interpolate,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub channels: usize,
|
||||
pub sample_rate: f64,
|
||||
pub frame_rate: f64,
|
||||
pub renormalize: bool,
|
||||
pub resample_method: ResampleMethod,
|
||||
pub seanet: seanet::Config,
|
||||
pub transformer: transformer::Config,
|
||||
pub quantizer_n_q: usize,
|
||||
pub quantizer_bins: usize,
|
||||
pub quantizer_dim: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml
|
||||
pub fn v0_1(num_codebooks: Option<usize>) -> Self {
|
||||
let seanet_cfg = seanet::Config {
|
||||
dimension: 512,
|
||||
channels: 1,
|
||||
causal: true,
|
||||
n_filters: 64,
|
||||
n_residual_layers: 1,
|
||||
activation: candle_nn::Activation::Elu(1.),
|
||||
compress: 2,
|
||||
dilation_base: 2,
|
||||
disable_norm_outer_blocks: 0,
|
||||
final_activation: None,
|
||||
kernel_size: 7,
|
||||
residual_kernel_size: 3,
|
||||
last_kernel_size: 3,
|
||||
lstm: 0,
|
||||
norm: conv::Norm::WeightNorm,
|
||||
pad_mode: conv::PadMode::Constant,
|
||||
ratios: vec![8, 6, 5, 4],
|
||||
true_skip: true,
|
||||
};
|
||||
let transformer_cfg = transformer::Config {
|
||||
d_model: seanet_cfg.dimension,
|
||||
num_heads: 8,
|
||||
num_layers: 8,
|
||||
causal: true,
|
||||
norm_first: true,
|
||||
bias_ff: false,
|
||||
bias_attn: false,
|
||||
layer_scale: Some(0.01),
|
||||
context: 250,
|
||||
conv_kernel_size: 5,
|
||||
use_conv_bias: true,
|
||||
use_conv_block: false,
|
||||
cross_attention: false,
|
||||
max_period: 10000,
|
||||
gating: None,
|
||||
norm: super::NormType::LayerNorm,
|
||||
positional_embedding: transformer::PositionalEmbedding::Rope,
|
||||
|
||||
dim_feedforward: 2048,
|
||||
kv_repeat: 1,
|
||||
conv_layout: true, // see builders.py
|
||||
max_seq_len: 8192, // the transformer works at 25 Hz so this is ~5 mins.
|
||||
};
|
||||
Config {
|
||||
channels: 1,
|
||||
sample_rate: 24_000.,
|
||||
frame_rate: 12.5,
|
||||
renormalize: true,
|
||||
resample_method: ResampleMethod::Conv,
|
||||
seanet: seanet_cfg,
|
||||
transformer: transformer_cfg,
|
||||
quantizer_n_q: num_codebooks.unwrap_or(16),
|
||||
quantizer_bins: 2048,
|
||||
quantizer_dim: 256,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Encodec {
|
||||
encoder: seanet::SeaNetEncoder,
|
||||
decoder: seanet::SeaNetDecoder,
|
||||
encoder_transformer: transformer::ProjectedTransformer,
|
||||
decoder_transformer: transformer::ProjectedTransformer,
|
||||
downsample: conv::ConvDownsample1d,
|
||||
upsample: conv::ConvTrUpsample1d,
|
||||
quantizer: quantization::SplitResidualVectorQuantizer,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl Encodec {
|
||||
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
|
||||
let dim = cfg.seanet.dimension;
|
||||
let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
|
||||
let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
|
||||
let encoder_transformer = transformer::ProjectedTransformer::new(
|
||||
dim,
|
||||
&[dim],
|
||||
&cfg.transformer,
|
||||
vb.pp("encoder_transformer"),
|
||||
)?;
|
||||
let decoder_transformer = transformer::ProjectedTransformer::new(
|
||||
dim,
|
||||
&[dim],
|
||||
&cfg.transformer,
|
||||
vb.pp("decoder_transformer"),
|
||||
)?;
|
||||
let quantizer = quantization::SplitResidualVectorQuantizer::new(
|
||||
/* dim */ cfg.quantizer_dim,
|
||||
/* input_dim */ Some(dim),
|
||||
/* output_dim */ Some(dim),
|
||||
/* n_q */ cfg.quantizer_n_q,
|
||||
/* bins */ cfg.quantizer_bins,
|
||||
vb.pp("quantizer"),
|
||||
)?;
|
||||
let encoder_frame_rate =
|
||||
cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
|
||||
|
||||
let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
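// With the v0_1 config, encoder_frame_rate = 24_000 / (8 * 6 * 5 * 4) = 25 Hz and
// downsample_stride = 25 / 12.5 = 2.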
|
||||
// `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate.
|
||||
let downsample = conv::ConvDownsample1d::new(
|
||||
/* stride */ downsample_stride,
|
||||
/* dim */ dim,
|
||||
/* causal */ true,
|
||||
/* learnt */ true,
|
||||
vb.pp("downsample"),
|
||||
)?;
|
||||
let upsample = conv::ConvTrUpsample1d::new(
|
||||
/* stride */ downsample_stride,
|
||||
/* dim */ dim,
|
||||
/* causal */ true,
|
||||
/* learnt */ true,
|
||||
vb.pp("upsample"),
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
encoder_transformer,
|
||||
decoder_transformer,
|
||||
quantizer,
|
||||
downsample,
|
||||
upsample,
|
||||
config: cfg,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.encoder.forward(xs)?;
|
||||
self.encoder_transformer.reset_state();
|
||||
let xs = self.encoder_transformer.forward(&xs)?;
|
||||
let xs = &xs[0];
|
||||
xs.apply(&self.downsample)
|
||||
}
|
||||
|
||||
pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.encoder.forward(xs)?;
|
||||
self.encoder_transformer.reset_state();
|
||||
let xs = self.encoder_transformer.forward(&xs)?;
|
||||
let xs = &xs[0];
|
||||
let xs = xs.apply(&self.downsample)?;
|
||||
let codes = self.quantizer.encode(&xs)?;
|
||||
Ok(codes)
|
||||
}
|
||||
|
||||
pub fn encode_step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let xs = self.encoder.step(xs)?;
|
||||
let xs = self.encoder_transformer.step(&xs)?;
|
||||
let xs = self.downsample.step(&xs)?;
|
||||
match xs.as_option() {
|
||||
None => Ok(().into()),
|
||||
Some(xs) => {
|
||||
let codes = self.quantizer.encode(xs)?;
|
||||
Ok(codes.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
|
||||
let emb = self.quantizer.decode(codes)?;
|
||||
let emb = emb.apply(&self.upsample)?;
|
||||
self.decoder_transformer.reset_state();
|
||||
let outs = self.decoder_transformer.forward(&emb)?;
|
||||
let out = &outs[0];
|
||||
self.decoder.forward(out)
|
||||
}
|
||||
|
||||
pub fn decode_step(&mut self, codes: &StreamTensor) -> Result<StreamTensor> {
|
||||
let emb = match codes.as_option() {
|
||||
Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
|
||||
None => StreamTensor::empty(),
|
||||
};
|
||||
let emb = self.upsample.step(&emb)?;
|
||||
let out = self.decoder_transformer.step(&emb)?;
|
||||
self.decoder.step(&out)
|
||||
}
|
||||
|
||||
pub fn reset_state(&mut self) {
|
||||
self.encoder.reset_state();
|
||||
self.encoder_transformer.reset_state();
|
||||
self.decoder.reset_state();
|
||||
self.decoder_transformer.reset_state();
|
||||
self.upsample.reset_state();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Encodec> {
|
||||
let vb =
|
||||
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
|
||||
let cfg = Config::v0_1(num_codebooks);
|
||||
let encodec = Encodec::new(cfg, vb)?;
|
||||
Ok(encodec)
|
||||
}
|
candle-transformers/src/models/mimi/mod.rs (new file) | 22
@@ -0,0 +1,22 @@
// Adapted from the reference implementation at:
// https://github.com/kyutai-labs/moshi
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

pub use candle;
pub use candle_nn;

pub mod conv;
pub mod encodec;
pub mod quantization;
pub mod seanet;
pub mod transformer;

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum NormType {
    RmsNorm,
    LayerNorm,
}

pub use encodec::{load, Config, Encodec as Model};
candle-transformers/src/models/mimi/quantization.rs (new file) | 404
@@ -0,0 +1,404 @@
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
|
||||
use candle_nn::{linear, Linear, VarBuilder};
|
||||
|
||||
struct CodebookEncode;
|
||||
|
||||
impl candle::CustomOp2 for CodebookEncode {
|
||||
fn name(&self) -> &'static str {
|
||||
"cb"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
lhs_storage: &candle::CpuStorage,
|
||||
lhs_layout: &Layout,
|
||||
rhs_storage: &candle::CpuStorage,
|
||||
rhs_layout: &Layout,
|
||||
) -> Result<(candle::CpuStorage, Shape)> {
|
||||
use rayon::prelude::*;
|
||||
|
||||
let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
|
||||
let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
|
||||
if lhs_dim2 != rhs_dim2 {
|
||||
candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
|
||||
}
|
||||
if lhs_dim2 == 0 {
|
||||
candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
|
||||
}
|
||||
let lhs = match lhs_layout.contiguous_offsets() {
|
||||
None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
|
||||
Some((o1, o2)) => {
|
||||
let slice = lhs_storage.as_slice::<f32>()?;
|
||||
&slice[o1..o2]
|
||||
}
|
||||
};
|
||||
let rhs = match rhs_layout.contiguous_offsets() {
|
||||
None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
|
||||
Some((o1, o2)) => {
|
||||
let slice = rhs_storage.as_slice::<f32>()?;
|
||||
&slice[o1..o2]
|
||||
}
|
||||
};
|
||||
let dst = (0..lhs_dim1)
|
||||
.into_par_iter()
|
||||
.map(|idx1| {
|
||||
let mut where_min = 0;
|
||||
let mut min_dist = f32::INFINITY;
|
||||
let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
|
||||
for idx2 in 0..rhs_dim1 {
|
||||
let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
|
||||
let mut dist = 0f32;
|
||||
for (a, b) in lhs.iter().zip(rhs.iter()) {
|
||||
dist += (a - b) * (a - b)
|
||||
}
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
where_min = idx2;
|
||||
}
|
||||
}
|
||||
where_min as u32
|
||||
})
|
||||
.collect();
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, (lhs_dim1,).into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EuclideanCodebook {
|
||||
initialized: Tensor,
|
||||
cluster_usage: Tensor,
|
||||
embedding_sum: Tensor,
|
||||
embedding: Tensor,
|
||||
c2: Tensor,
|
||||
epsilon: f64,
|
||||
dim: usize,
|
||||
span_encode: tracing::Span,
|
||||
span_decode: tracing::Span,
|
||||
}
|
||||
|
||||
impl EuclideanCodebook {
|
||||
pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let epsilon = 1e-5;
|
||||
let initialized = vb.get(1, "initialized")?;
|
||||
let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
|
||||
let embedding_sum = vb.get((codebook_size, dim), "embed_sum")?;
|
||||
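// The checkpoint stores per-code running sums and usage counts; the actual
// codebook embeddings are recovered by dividing the sums by the (clamped) counts.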
let embedding = {
|
||||
let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
|
||||
embedding_sum.broadcast_div(&cluster_usage)?
|
||||
};
|
||||
let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
|
||||
Ok(Self {
|
||||
initialized,
|
||||
cluster_usage,
|
||||
embedding_sum,
|
||||
embedding,
|
||||
c2,
|
||||
epsilon,
|
||||
dim,
|
||||
span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
|
||||
span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-decode"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_encode.enter();
|
||||
let mut target_shape = xs.dims().to_vec();
|
||||
target_shape.pop();
|
||||
let xs = xs.flatten_to(D::Minus2)?;
|
||||
let _ = xs.dims2()?;
|
||||
// TODO: avoid repeating this.
|
||||
let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
|
||||
let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
|
||||
// Manual cdist implementation.
|
||||
let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
|
||||
let dists = diff.sqr()?.sum(D::Minus1)?;
|
||||
let codes = dists.argmin(D::Minus1)?;
|
||||
codes.reshape(target_shape)
|
||||
}
|
||||
|
||||
pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_encode.enter();
|
||||
let mut target_shape = xs.dims().to_vec();
|
||||
target_shape.pop();
|
||||
let xs = xs.flatten_to(D::Minus2)?;
|
||||
let _ = xs.dims2()?;
|
||||
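// Nearest-code search: ||x - e||^2 = ||x||^2 - 2<x, e> + ||e||^2, and ||x||^2 is
// constant over the codes, so the argmin reduces to argmin(||e||^2 / 2 - <x, e>)
// with c2 = ||e||^2 / 2 precomputed in the constructor.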
let dot_prod = xs.matmul(&self.embedding.t()?)?;
|
||||
let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
|
||||
codes.reshape(target_shape)
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_encode.enter();
|
||||
let mut target_shape = xs.dims().to_vec();
|
||||
target_shape.pop();
|
||||
let xs = xs.flatten_to(D::Minus2)?;
|
||||
let _ = xs.dims2()?;
|
||||
let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
|
||||
codes.reshape(target_shape)
|
||||
}
|
||||
|
||||
pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_decode.enter();
|
||||
// let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?;
|
||||
let mut final_dims = indexes.dims().to_vec();
|
||||
final_dims.push(self.dim);
|
||||
let indexes = indexes.flatten_all()?;
|
||||
let values = self.embedding.index_select(&indexes, 0)?;
|
||||
let values = values.reshape(final_dims)?;
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VectorQuantization {
|
||||
project_in: Option<Linear>,
|
||||
project_out: Option<Linear>,
|
||||
codebook: EuclideanCodebook,
|
||||
}
|
||||
|
||||
impl VectorQuantization {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
codebook_size: usize,
|
||||
codebook_dim: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let codebook_dim = codebook_dim.unwrap_or(dim);
|
||||
let (project_in, project_out) = if codebook_dim == dim {
|
||||
(None, None)
|
||||
} else {
|
||||
let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
|
||||
let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
|
||||
(Some(p_in), Some(p_out))
|
||||
};
|
||||
let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("codebook"))?;
|
||||
Ok(Self {
|
||||
project_in,
|
||||
project_out,
|
||||
codebook,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.t()?.apply(&self.project_in.as_ref())?;
|
||||
self.codebook.encode_slow(&xs)
|
||||
}
|
||||
|
||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let quantized = self.codebook.decode(codes)?;
|
||||
let quantized = match &self.project_out {
|
||||
            None => quantized,
            Some(p) => quantized.apply(p)?,
        };
        quantized.t()
    }
}

#[derive(Debug, Clone)]
pub struct ResidualVectorQuantization {
    layers: Vec<VectorQuantization>,
}

impl ResidualVectorQuantization {
    pub fn new(
        n_q: usize,
        dim: usize,
        codebook_size: usize,
        codebook_dim: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = vb.pp("layers");
        let mut layers = Vec::with_capacity(n_q);
        for i in 0..n_q {
            let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
            layers.push(layer)
        }
        Ok(Self { layers })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let mut codes = Vec::with_capacity(self.layers.len());
        let mut residual = xs.clone();
        for layer in self.layers.iter() {
            let indices = layer.encode(&residual)?;
            let quantized = layer.decode(&indices)?;
            residual = (residual - quantized)?;
            codes.push(indices)
        }
        Tensor::stack(&codes, 0)
    }

    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
        if self.layers.is_empty() {
            candle::bail!("empty layers in ResidualVectorQuantization")
        }
        if self.layers.len() != xs.dim(0)? {
            candle::bail!(
                "mismatch between the number of layers {} and the code shape {:?}",
                self.layers.len(),
                xs.shape()
            )
        }
        let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
        for (i, layer) in self.layers.iter().enumerate().skip(1) {
            let xs = xs.i(i)?;
            quantized = (quantized + layer.decode(&xs))?
        }
        Ok(quantized)
    }
}

#[allow(unused)]
#[derive(Debug, Clone)]
pub struct ResidualVectorQuantizer {
    vq: ResidualVectorQuantization,
    input_proj: Option<candle_nn::Conv1d>,
    output_proj: Option<candle_nn::Conv1d>,
}

impl ResidualVectorQuantizer {
    pub fn new(
        dim: usize,
        input_dim: Option<usize>,
        output_dim: Option<usize>,
        n_q: usize,
        bins: usize,
        force_projection: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let input_dim = input_dim.unwrap_or(dim);
        let output_dim = output_dim.unwrap_or(dim);

        let input_proj = if input_dim == dim && !force_projection {
            None
        } else {
            let c = candle_nn::conv1d_no_bias(
                input_dim,
                dim,
                1,
                Default::default(),
                vb.pp("input_proj"),
            )?;
            Some(c)
        };
        let output_proj = if output_dim == dim && !force_projection {
            None
        } else {
            let c = candle_nn::conv1d_no_bias(
                dim,
                output_dim,
                1,
                Default::default(),
                vb.pp("output_proj"),
            )?;
            Some(c)
        };

        let vq = ResidualVectorQuantization::new(
            n_q, dim, /* codebook_size */ bins, /* codebook_dim */ None, vb,
        )?;
        Ok(Self {
            vq,
            input_proj,
            output_proj,
        })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
        codes.transpose(0, 1)
    }

    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        // codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        let codes = codes.transpose(0, 1)?;
        let quantized = self.vq.decode(&codes)?;
        match &self.output_proj {
            None => Ok(quantized),
            Some(p) => quantized.apply(p),
        }
    }
}

// we do not use any codebook_offset at the moment. When reconstructing the codes, we could just
// concatenate the indexes.
#[derive(Debug, Clone)]
pub struct SplitResidualVectorQuantizer {
    rvq_first: ResidualVectorQuantizer,
    rvq_rest: ResidualVectorQuantizer,
    n_q: usize,
    span_encode: tracing::Span,
    span_decode: tracing::Span,
}

impl SplitResidualVectorQuantizer {
    pub fn new(
        dim: usize,
        input_dim: Option<usize>,
        output_dim: Option<usize>,
        n_q: usize,
        bins: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let rvq_first = ResidualVectorQuantizer::new(
            dim,
            input_dim,
            output_dim,
            1,
            bins,
            true,
            vb.pp("semantic_residual_vector_quantizer"),
        )?;
        let rvq_rest = ResidualVectorQuantizer::new(
            dim,
            input_dim,
            output_dim,
            n_q - 1,
            bins,
            true,
            vb.pp("acoustic_residual_vector_quantizer"),
        )?;
        let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
        let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
        Ok(Self {
            rvq_first,
            rvq_rest,
            n_q,
            span_encode,
            span_decode,
        })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span_encode.enter();
        let codes = self.rvq_first.encode(xs)?;
        if self.n_q > 1 {
            // We encode xs again here rather than the residual. The decomposition is not
            // hierarchical but rather having semantic tokens for rvq_first and the acoustic tokens
            // for rvq_rest.
            let rest_codes = self.rvq_rest.encode(xs)?;
            Tensor::cat(&[codes, rest_codes], 1)
        } else {
            Ok(codes)
        }
    }

    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        // codes is [B, K, T], with T frames, K nb of codebooks.
        let _enter = self.span_decode.enter();
        let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
        let quantized = if self.n_q > 1 {
            (quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
        } else {
            quantized
        };
        Ok(quantized)
    }
}
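The encode/decode pair above is the core residual trick: each `VectorQuantization` layer only quantizes what the previous layers failed to reconstruct, and decoding sums the per-layer codewords. Below is a minimal, self-contained sketch of that loop with made-up scalar codebooks; nothing in it comes from the diff, real codebooks are learned and vector valued.

```rust
// Toy illustration of the residual quantization loop above.
fn nearest(codebook: &[f32], x: f32) -> (usize, f32) {
    codebook
        .iter()
        .copied()
        .enumerate()
        .min_by(|a, b| (a.1 - x).abs().partial_cmp(&(b.1 - x).abs()).unwrap())
        .unwrap()
}

fn main() {
    // One (scalar) codebook per quantization stage, coarse to fine. Hypothetical values.
    let codebooks = [
        vec![-1.0f32, 0.0, 1.0],
        vec![-0.5, 0.0, 0.5],
        vec![-0.25, 0.0, 0.25],
    ];
    let target = 0.8f32;
    let mut residual = target;
    let mut reconstruction = 0.0f32;
    let mut codes = Vec::new();
    for cb in &codebooks {
        let (idx, q) = nearest(cb, residual);
        codes.push(idx);
        reconstruction += q; // decode sums the per-stage codewords
        residual -= q; // the next stage only sees the remaining error
    }
    println!("codes = {codes:?}, reconstruction = {reconstruction}, target = {target}");
}
```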
465
candle-transformers/src/models/mimi/seanet.rs
Normal file
465
candle-transformers/src/models/mimi/seanet.rs
Normal file
@ -0,0 +1,465 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor};
use candle_nn::VarBuilder;

use super::conv::{StreamableConv1d, StreamableConvTranspose1d};

#[derive(Debug, Clone)]
pub struct Config {
    pub dimension: usize,
    pub channels: usize,
    pub causal: bool,
    pub n_filters: usize,
    pub n_residual_layers: usize,
    pub ratios: Vec<usize>,
    pub activation: candle_nn::Activation,
    pub norm: super::conv::Norm,
    pub kernel_size: usize,
    pub residual_kernel_size: usize,
    pub last_kernel_size: usize,
    pub dilation_base: usize,
    pub pad_mode: super::conv::PadMode,
    pub true_skip: bool,
    pub compress: usize,
    pub lstm: usize,
    pub disable_norm_outer_blocks: usize,
    pub final_activation: Option<candle_nn::Activation>,
}

#[derive(Debug, Clone)]
pub struct SeaNetResnetBlock {
    block: Vec<StreamableConv1d>,
    shortcut: Option<StreamableConv1d>,
    activation: candle_nn::Activation,
    skip_op: candle::StreamingBinOp,
    span: tracing::Span,
}

impl SeaNetResnetBlock {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        dim: usize,
        k_sizes_and_dilations: &[(usize, usize)],
        activation: candle_nn::Activation,
        norm: Option<super::conv::Norm>,
        causal: bool,
        pad_mode: super::conv::PadMode,
        compress: usize,
        true_skip: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
        let hidden = dim / compress;
        let vb_b = vb.pp("block");
        for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
            let in_c = if i == 0 { dim } else { hidden };
            let out_c = if i == k_sizes_and_dilations.len() - 1 {
                dim
            } else {
                hidden
            };
            let c = StreamableConv1d::new(
                in_c,
                out_c,
                /* k_size */ *k_size,
                /* stride */ 1,
                /* dilation */ *dilation,
                /* groups */ 1,
                /* bias */ true,
                /* causal */ causal,
                /* norm */ norm,
                /* pad_mode */ pad_mode,
                vb_b.pp(2 * i + 1),
            )?;
            block.push(c)
        }
        let shortcut = if true_skip {
            None
        } else {
            let c = StreamableConv1d::new(
                dim,
                dim,
                /* k_size */ 1,
                /* stride */ 1,
                /* dilation */ 1,
                /* groups */ 1,
                /* bias */ true,
                /* causal */ causal,
                /* norm */ norm,
                /* pad_mode */ pad_mode,
                vb.pp("shortcut"),
            )?;
            Some(c)
        };
        Ok(Self {
            block,
            shortcut,
            activation,
            skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
            span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
        })
    }
}

impl Module for SeaNetResnetBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut ys = xs.clone();
        for block in self.block.iter() {
            ys = ys.apply(&self.activation)?.apply(block)?;
        }
        match self.shortcut.as_ref() {
            None => ys + xs,
            Some(shortcut) => ys + xs.apply(shortcut),
        }
    }
}

impl StreamingModule for SeaNetResnetBlock {
    fn reset_state(&mut self) {
        for block in self.block.iter_mut() {
            block.reset_state()
        }
        if let Some(shortcut) = self.shortcut.as_mut() {
            shortcut.reset_state()
        }
    }

    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
        let _enter = self.span.enter();
        let mut ys = xs.clone();
        for block in self.block.iter_mut() {
            ys = block.step(&ys.apply(&self.activation)?)?;
        }
        match self.shortcut.as_ref() {
            None => self.skip_op.step(&ys, xs),
            Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?),
        }
    }
}

#[derive(Debug, Clone)]
struct EncoderLayer {
    residuals: Vec<SeaNetResnetBlock>,
    downsample: StreamableConv1d,
}

#[derive(Debug, Clone)]
pub struct SeaNetEncoder {
    init_conv1d: StreamableConv1d,
    activation: candle_nn::Activation,
    layers: Vec<EncoderLayer>,
    final_conv1d: StreamableConv1d,
    span: tracing::Span,
}

impl SeaNetEncoder {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        if cfg.lstm > 0 {
            candle::bail!("seanet lstm is not supported")
        }
        let n_blocks = 2 + cfg.ratios.len();
        let mut mult = 1usize;
        let init_norm = if cfg.disable_norm_outer_blocks >= 1 {
            None
        } else {
            Some(cfg.norm)
        };
        let mut layer_idx = 0;
        let vb = vb.pp("layers");
        let init_conv1d = StreamableConv1d::new(
            cfg.channels,
            mult * cfg.n_filters,
            cfg.kernel_size,
            /* stride */ 1,
            /* dilation */ 1,
            /* groups */ 1,
            /* bias */ true,
            /* causal */ cfg.causal,
            /* norm */ init_norm,
            /* pad_mode */ cfg.pad_mode,
            vb.pp(layer_idx),
        )?;
        layer_idx += 1;
        let mut layers = Vec::with_capacity(cfg.ratios.len());

        for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
            let norm = if cfg.disable_norm_outer_blocks >= i + 2 {
                None
            } else {
                Some(cfg.norm)
            };
            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
            for j in 0..cfg.n_residual_layers {
                let resnet_block = SeaNetResnetBlock::new(
                    mult * cfg.n_filters,
                    &[
                        (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
                        (1, 1),
                    ],
                    cfg.activation,
                    norm,
                    cfg.causal,
                    cfg.pad_mode,
                    cfg.compress,
                    cfg.true_skip,
                    vb.pp(layer_idx),
                )?;
                residuals.push(resnet_block);
                layer_idx += 1;
            }
            let downsample = StreamableConv1d::new(
                mult * cfg.n_filters,
                mult * cfg.n_filters * 2,
                /* k_size */ ratio * 2,
                /* stride */ ratio,
                /* dilation */ 1,
                /* groups */ 1,
                /* bias */ true,
                /* causal */ true,
                /* norm */ norm,
                /* pad_mode */ cfg.pad_mode,
                vb.pp(layer_idx + 1),
            )?;
            layer_idx += 2;
            let layer = EncoderLayer {
                downsample,
                residuals,
            };
            layers.push(layer);
            mult *= 2
        }

        let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks {
            None
        } else {
            Some(cfg.norm)
        };
        let final_conv1d = StreamableConv1d::new(
            mult * cfg.n_filters,
            cfg.dimension,
            cfg.last_kernel_size,
            /* stride */ 1,
            /* dilation */ 1,
            /* groups */ 1,
            /* bias */ true,
            /* causal */ cfg.causal,
            /* norm */ final_norm,
            /* pad_mode */ cfg.pad_mode,
            vb.pp(layer_idx + 1),
        )?;
        Ok(Self {
            init_conv1d,
            activation: cfg.activation,
            layers,
            final_conv1d,
            span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
        })
    }
}

impl Module for SeaNetEncoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut xs = xs.apply(&self.init_conv1d)?;
        for layer in self.layers.iter() {
            for residual in layer.residuals.iter() {
                xs = xs.apply(residual)?
            }
            xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
        }
        xs.apply(&self.activation)?.apply(&self.final_conv1d)
    }
}

impl StreamingModule for SeaNetEncoder {
    fn reset_state(&mut self) {
        self.init_conv1d.reset_state();
        self.layers.iter_mut().for_each(|v| {
            v.residuals.iter_mut().for_each(|v| v.reset_state());
            v.downsample.reset_state()
        });
        self.final_conv1d.reset_state();
    }

    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
        let _enter = self.span.enter();
        let mut xs = self.init_conv1d.step(xs)?;
        for layer in self.layers.iter_mut() {
            for residual in layer.residuals.iter_mut() {
                xs = residual.step(&xs)?;
            }
            xs = layer.downsample.step(&xs.apply(&self.activation)?)?;
        }
        self.final_conv1d.step(&xs.apply(&self.activation)?)
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    upsample: StreamableConvTranspose1d,
    residuals: Vec<SeaNetResnetBlock>,
}

#[derive(Debug, Clone)]
pub struct SeaNetDecoder {
    init_conv1d: StreamableConv1d,
    activation: candle_nn::Activation,
    layers: Vec<DecoderLayer>,
    final_conv1d: StreamableConv1d,
    final_activation: Option<candle_nn::Activation>,
    span: tracing::Span,
}

impl SeaNetDecoder {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        if cfg.lstm > 0 {
            candle::bail!("seanet lstm is not supported")
        }
        let n_blocks = 2 + cfg.ratios.len();
        let mut mult = 1 << cfg.ratios.len();
        let init_norm = if cfg.disable_norm_outer_blocks == n_blocks {
            None
        } else {
            Some(cfg.norm)
        };
        let mut layer_idx = 0;
        let vb = vb.pp("layers");
        let init_conv1d = StreamableConv1d::new(
            cfg.dimension,
            mult * cfg.n_filters,
            cfg.kernel_size,
            /* stride */ 1,
            /* dilation */ 1,
            /* groups */ 1,
            /* bias */ true,
            /* causal */ cfg.causal,
            /* norm */ init_norm,
            /* pad_mode */ cfg.pad_mode,
            vb.pp(layer_idx),
        )?;
        layer_idx += 1;
        let mut layers = Vec::with_capacity(cfg.ratios.len());
        for (i, &ratio) in cfg.ratios.iter().enumerate() {
            let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
                None
            } else {
                Some(cfg.norm)
            };
            let upsample = StreamableConvTranspose1d::new(
                mult * cfg.n_filters,
                mult * cfg.n_filters / 2,
                /* k_size */ ratio * 2,
                /* stride */ ratio,
                /* groups */ 1,
                /* bias */ true,
                /* causal */ true,
                /* norm */ norm,
                vb.pp(layer_idx + 1),
            )?;
            layer_idx += 2;

            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
            for j in 0..cfg.n_residual_layers {
                let resnet_block = SeaNetResnetBlock::new(
                    mult * cfg.n_filters / 2,
                    &[
                        (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
                        (1, 1),
                    ],
                    cfg.activation,
                    norm,
                    cfg.causal,
                    cfg.pad_mode,
                    cfg.compress,
                    cfg.true_skip,
                    vb.pp(layer_idx),
                )?;
                residuals.push(resnet_block);
                layer_idx += 1;
            }
            let layer = DecoderLayer {
                upsample,
                residuals,
            };
            layers.push(layer);
            mult /= 2
        }
        let final_norm = if cfg.disable_norm_outer_blocks >= 1 {
            None
        } else {
            Some(cfg.norm)
        };
        let final_conv1d = StreamableConv1d::new(
            cfg.n_filters,
            cfg.channels,
            cfg.last_kernel_size,
            /* stride */ 1,
            /* dilation */ 1,
            /* groups */ 1,
            /* bias */ true,
            /* causal */ cfg.causal,
            /* norm */ final_norm,
            /* pad_mode */ cfg.pad_mode,
            vb.pp(layer_idx + 1),
        )?;
        Ok(Self {
            init_conv1d,
            activation: cfg.activation,
            layers,
            final_conv1d,
            final_activation: cfg.final_activation,
            span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
        })
    }
}

impl Module for SeaNetDecoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut xs = xs.apply(&self.init_conv1d)?;
        for layer in self.layers.iter() {
            xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
            for residual in layer.residuals.iter() {
                xs = xs.apply(residual)?
            }
        }
        let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
        let xs = match self.final_activation.as_ref() {
            None => xs,
            Some(act) => xs.apply(act)?,
        };
        Ok(xs)
    }
}

impl StreamingModule for SeaNetDecoder {
    fn reset_state(&mut self) {
        self.init_conv1d.reset_state();
        self.layers.iter_mut().for_each(|v| {
            v.residuals.iter_mut().for_each(|v| v.reset_state());
            v.upsample.reset_state()
        });
        self.final_conv1d.reset_state();
    }

    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
        let _enter = self.span.enter();
        let mut xs = self.init_conv1d.step(xs)?;
        for layer in self.layers.iter_mut() {
            xs = layer.upsample.step(&xs.apply(&self.activation)?)?;
            for residual in layer.residuals.iter_mut() {
                xs = residual.step(&xs)?;
            }
        }
        let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?;
        let xs = match self.final_activation.as_ref() {
            None => xs,
            Some(act) => xs.apply(act)?,
        };
        Ok(xs)
    }
}
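One practical consequence of the encoder structure above: each entry of `Config::ratios` is the stride of one downsampling convolution, so the hop size in samples is the product of the ratios, and that hop fixes the audio-token frame rate. A back-of-the-envelope sketch with hypothetical values follows; the real ratios come from the model config, not from this diff.

```rust
// Hop size / frame rate sketch for a SeaNet-style encoder (hypothetical numbers).
fn main() {
    let sample_rate = 24_000u32;
    let ratios = [8u32, 6, 5, 4]; // hypothetical; read the actual value from the model config
    let hop: u32 = ratios.iter().copied().product();
    println!(
        "hop = {hop} samples -> {} frames per second",
        sample_rate as f32 / hop as f32
    );
}
```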
802
candle-transformers/src/models/mimi/transformer.rs
Normal file
802
candle-transformers/src/models/mimi/transformer.rs
Normal file
@ -0,0 +1,802 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use candle::{DType, Device, IndexOp, Module, Result, StreamTensor, StreamingModule, Tensor, D};
use candle_nn::{linear_no_bias, Linear, VarBuilder};
use std::sync::Arc;

fn linear(in_d: usize, out_d: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    if bias {
        candle_nn::linear(in_d, out_d, vb)
    } else {
        linear_no_bias(in_d, out_d, vb)
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PositionalEmbedding {
    Rope,
    Sin,
    None,
}

#[derive(Debug, Clone)]
pub struct Config {
    pub d_model: usize,
    pub num_heads: usize,
    pub num_layers: usize,
    pub causal: bool,
    pub norm_first: bool,
    pub bias_ff: bool,
    pub bias_attn: bool,
    pub layer_scale: Option<f64>,
    pub positional_embedding: PositionalEmbedding,
    pub use_conv_block: bool,
    pub cross_attention: bool,
    pub conv_kernel_size: usize,
    pub use_conv_bias: bool,
    pub gating: Option<candle_nn::Activation>,
    pub norm: super::NormType,
    pub context: usize,
    pub max_period: usize,
    pub max_seq_len: usize,

    pub kv_repeat: usize,
    pub dim_feedforward: usize,
    pub conv_layout: bool,
}

#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
    span: tracing::Span,
}

impl RotaryEmbedding {
    pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result<Self> {
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
            span: tracing::span!(tracing::Level::TRACE, "rot"),
        })
    }

    pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?;
        let qk_dtype = qk.dtype();
        let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
        let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
        candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype)
    }
}

#[derive(Debug, Clone)]
pub struct LayerScale {
    scale: Tensor,
}

impl LayerScale {
    pub fn new(d_model: usize, _init: f64, vb: VarBuilder) -> Result<Self> {
        let scale = vb.get(d_model, "scale")?;
        Ok(Self { scale })
    }
}

impl Module for LayerScale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.broadcast_mul(&self.scale)
    }
}

pub(crate) fn get_mask(
    size1: usize,
    size2: usize,
    context: usize,
    device: &Device,
) -> Result<Tensor> {
    let mask: Vec<_> = (0..size1)
        .flat_map(|i| {
            (0..size2)
                .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
        })
        .collect();
    Tensor::from_slice(&mask, (size1, size2), device)
}
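`get_mask` builds a banded causal mask: entry `(i, j)` is 1 (masked) when key `j` lies in the future of query `i`, or more than `context` steps in its past, once the `size1` new queries are right-aligned with the `size2` keys. A standalone check of that rule for a tiny case, not part of the diff:

```rust
// Prints the mask rows for size1 = 3 queries over size2 = 5 keys with context = 2.
fn main() {
    let (size1, size2, context) = (3usize, 5usize, 2usize);
    for i in 0..size1 {
        let row: Vec<u8> = (0..size2)
            .map(|j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
            .collect();
        println!("{row:?}");
    }
    // prints [0, 0, 0, 1, 1] / [1, 0, 0, 0, 1] / [1, 1, 0, 0, 0]:
    // each query sees itself plus at most `context` previous keys.
}
```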
#[derive(Debug, Clone)]
pub struct StreamingMultiheadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    kv_repeat: usize,
    num_heads: usize,
    context: usize,
    neg_inf: Tensor,
    rope: Option<Arc<RotaryEmbedding>>,
    kv_cache: candle_nn::kv_cache::KvCache,
    pos: usize,
    use_flash_attn: bool,
    span: tracing::Span,
}

impl StreamingMultiheadAttention {
    pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_kv = cfg.num_heads / cfg.kv_repeat;
        let kv_dim = num_kv * (embed_dim / cfg.num_heads);
        let q_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("q_proj"))?;
        let k_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("k_proj"))?;
        let v_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("v_proj"))?;
        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("o_proj"))?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            rope: rope.clone(),
            kv_repeat: cfg.kv_repeat,
            num_heads: cfg.num_heads,
            context: cfg.context,
            neg_inf,
            kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
            pos: 0,
            use_flash_attn: false,
            span: tracing::span!(tracing::Level::TRACE, "mha"),
        })
    }

    pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        if self.kv_repeat != 1 {
            candle::bail!("only kv-repeat = 1 is supported")
        }
        let (b, t, hd) = xs.dims3()?;
        let head_dim = hd / self.num_heads;
        let q = xs
            .apply(&self.q_proj)?
            .reshape((b, t, self.num_heads, head_dim))?;
        let k = xs
            .apply(&self.k_proj)?
            .reshape((b, t, self.num_heads, head_dim))?;
        let v = xs
            .apply(&self.v_proj)?
            .reshape((b, t, self.num_heads, head_dim))?;
        // qk_layer_norm = None
        // kv_repeat = 1, otherwise we would need repeat_kv
        let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
        let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
        let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
        if let Some(rope) = &self.rope {
            q = rope.apply_rotary_emb(&q, self.pos)?;
            k = rope.apply_rotary_emb(&k, self.pos)?;
        }

        let (k, v) = {
            self.pos += k.dim(2)?;
            self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?
        };
        // The KV cache keeps all the data at the moment, we want to trim
        // down the part that comes from the cache to at most context to
        // be coherent with the mask shape we provide.
        let k_len = k.dim(2)?;
        let k_target_len = t + usize::min(self.context, k_len - t);
        let (k, v) = if k_target_len < k_len {
            let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
            let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
            (k, v)
        } else {
            (k.clone(), v.clone())
        };

        let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            let softmax_scale = 1f32 / (head_dim as f32).sqrt();
            flash_attn(&q, &k, &v, softmax_scale, t > 1)?.transpose(1, 2)?
        } else {
            let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
            let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;

            let pre_ws = match mask {
                None => pre_ws,
                Some(mask) => {
                    let mask = mask.broadcast_left((b, self.num_heads))?;
                    let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
                    mask.where_cond(&neg_inf, &pre_ws)?
                }
            };

            let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
            ws.matmul(&v)? // b,h,t,d
        };
        let xs = xs
            .transpose(1, 2)? // b,t,h,d
            .reshape((b, t, hd))?
            .apply(&self.out_proj)?;
        Ok(xs)
    }

    pub fn reset_kv_cache(&mut self) {
        self.kv_cache.reset()
    }

    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
        self.kv_cache = kv_cache
    }
}
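The trimming step in `forward` keeps the keys/values consistent with the mask: after appending `t` new steps the cache holds `k_len` entries, and only the last `context` cached entries plus the `t` new ones are retained. A quick sanity check of that arithmetic, not from the diff:

```rust
// Verifies k_target_len = t + min(context, k_len - t) for a few cache sizes.
fn main() {
    let (t, context) = (4usize, 8usize);
    for k_len in [4usize, 10, 20] {
        let k_target_len = t + usize::min(context, k_len - t);
        let kept = usize::min(k_target_len, k_len);
        println!("k_len = {k_len:2} -> keep last {kept} entries");
    }
    // prints 4 -> 4, 10 -> 10, 20 -> 12 (i.e. at most t + context entries are kept)
}
```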
#[derive(Debug, Clone)]
pub struct StreamingMultiheadCrossAttention {
    in_proj_q: Linear,
    in_proj_k: Linear,
    in_proj_v: Linear,
    out_proj: Linear,
    kv_repeat: usize,
    num_heads: usize,
    neg_inf: Tensor,
    span: tracing::Span,
}

impl StreamingMultiheadCrossAttention {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_kv = cfg.num_heads / cfg.kv_repeat;
        let kv_dim = num_kv * (embed_dim / cfg.num_heads);
        let out_dim = embed_dim + 2 * kv_dim;
        let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
        let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
        let in_proj_weight_k = in_proj_weight.narrow(0, embed_dim, kv_dim)?;
        let in_proj_weight_v = in_proj_weight.narrow(0, embed_dim + kv_dim, kv_dim)?;
        let (in_proj_bias_q, in_proj_bias_k, in_proj_bias_v) = if cfg.bias_attn {
            let b = vb.get(out_dim, "in_proj_bias")?;
            let q = b.narrow(0, 0, embed_dim)?;
            let k = b.narrow(0, embed_dim, kv_dim)?;
            let v = b.narrow(0, embed_dim + kv_dim, kv_dim)?;
            (Some(q), Some(k), Some(v))
        } else {
            (None, None, None)
        };
        let in_proj_q = Linear::new(in_proj_weight_q, in_proj_bias_q);
        let in_proj_k = Linear::new(in_proj_weight_k, in_proj_bias_k);
        let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
        Ok(Self {
            in_proj_q,
            in_proj_k,
            in_proj_v,
            out_proj,
            kv_repeat: cfg.kv_repeat,
            num_heads: cfg.num_heads,
            neg_inf,
            span: tracing::span!(tracing::Level::TRACE, "mhca"),
        })
    }

    pub fn forward(&self, xs: &Tensor, ca_src: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        if self.kv_repeat != 1 {
            candle::bail!("only kv-repeat = 1 is supported")
        }
        let (b, t, hd) = xs.dims3()?;
        let head_dim = hd / self.num_heads;
        // time_dim = 1, layout: b,t,h,d
        let q = xs.apply(&self.in_proj_q)?;
        let k = ca_src.apply(&self.in_proj_k)?;
        let v = ca_src.apply(&self.in_proj_v)?;
        let (ca_b, ca_t, ca_dim) = k.dims3()?;
        let q = q.reshape((b, t, self.num_heads, head_dim))?;
        let k = k.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
        let v = v.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
        // qk_layer_norm = None
        // kv_repeat = 1, otherwise we would need repeat_kv
        let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
        let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
        let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d

        let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
        let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;

        let pre_ws = match mask {
            None => pre_ws,
            Some(mask) => {
                let mask = mask.broadcast_left((b, self.num_heads))?;
                let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
                mask.where_cond(&neg_inf, &pre_ws)?
            }
        };

        let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
        let xs = ws.matmul(&v)?; // b,h,t,d
        let xs = xs
            .transpose(1, 2)? // b,t,h,d
            .reshape((b, t, hd))?
            .apply(&self.out_proj)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub enum Mlp {
    NoGating {
        span1: tracing::Span,
        linear1: Linear,
        span2: tracing::Span,
        linear2: Linear,
        span: tracing::Span,
    },
    Gating {
        linear_in: Linear,
        linear_out: Linear,
        activation: candle_nn::Activation,
        span: tracing::Span,
    },
}

impl Mlp {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let d_model = cfg.d_model;
        let span = tracing::span!(tracing::Level::TRACE, "mlp");

        match cfg.gating {
            None => {
                let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
                let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
                let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("mlp.fc1"))?;
                let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("mlp.fc2"))?;
                Ok(Self::NoGating {
                    linear1,
                    linear2,
                    span,
                    span1,
                    span2,
                })
            }
            Some(activation) => {
                let vb = vb.pp("gating");
                let hidden = if cfg.dim_feedforward == 4 * d_model {
                    11 * d_model / 4
                } else {
                    2 * cfg.dim_feedforward / 3
                };
                // TODO: Maybe use bias_ff here?
                let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
                let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
                Ok(Self::Gating {
                    linear_in,
                    linear_out,
                    activation,
                    span,
                })
            }
        }
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::NoGating {
                linear1,
                linear2,
                span,
                span1,
                span2,
            } => {
                let _enter = span.enter();
                let xs = {
                    let _enter = span1.enter();
                    xs.apply(linear1)?
                };
                let xs = xs.gelu_erf()?;
                {
                    let _enter = span2.enter();
                    xs.apply(linear2)
                }
            }
            Self::Gating {
                linear_in,
                linear_out,
                activation,
                span,
            } => {
                let _enter = span.enter();
                let xs = xs.apply(linear_in)?;
                let (b, t, _) = xs.dims3()?;
                let xs = xs.reshape((b, t, 2, ()))?;
                let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
                xs.apply(linear_out)
            }
        }
    }
}
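The `hidden = 11 * d_model / 4` choice in the gated branch roughly preserves the parameter count of the ungated MLP when `dim_feedforward = 4 * d_model`: gating needs a `d_model x 2*hidden` input projection plus a `hidden x d_model` output projection. A rough count with a hypothetical `d_model`, not from the diff:

```rust
// Compares the weight counts of the ungated and gated MLP variants above.
fn main() {
    let d_model = 512usize; // hypothetical model width
    let dim_feedforward = 4 * d_model;
    let ungated = 2 * d_model * dim_feedforward; // fc1 + fc2 weights
    let hidden = 11 * d_model / 4;
    let gated = d_model * 2 * hidden + hidden * d_model; // linear_in + linear_out weights
    println!("ungated = {ungated}, gated = {gated}"); // 2_097_152 vs 2_162_688, within a few percent
}
```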
#[derive(Debug, Clone)]
pub struct RmsNorm {
    pub(crate) alpha: Tensor,
    pub(crate) eps: f32,
}

impl RmsNorm {
    pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
        let alpha = vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?;
        Ok(Self { alpha, eps })
    }
}

impl Module for RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
    }
}

#[derive(Debug, Clone)]
pub enum Norm {
    LayerNorm(candle_nn::LayerNorm),
    RmsNorm(RmsNorm),
}

impl Norm {
    pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let norm = match cfg.norm {
            super::NormType::LayerNorm => {
                let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
                Self::LayerNorm(norm)
            }
            super::NormType::RmsNorm => {
                let norm = RmsNorm::new(d_model, 1e-8, vb)?;
                Self::RmsNorm(norm)
            }
        };
        Ok(norm)
    }
}

impl Module for Norm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::LayerNorm(m) => m.forward(xs),
            Self::RmsNorm(m) => m.forward(xs),
        }
    }
}

#[derive(Debug, Clone)]
pub struct StreamingTransformerLayer {
    self_attn: StreamingMultiheadAttention,
    mlp: Mlp,
    norm1: Norm,
    norm2: Norm,
    layer_scale_1: Option<LayerScale>,
    layer_scale_2: Option<LayerScale>,
    cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
    norm_first: bool,
    span: tracing::Span,
}

impl StreamingTransformerLayer {
    pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        if cfg.use_conv_block {
            candle::bail!("conv-block is not supported")
        }
        let d_model = cfg.d_model;
        let mlp = Mlp::new(cfg, vb.clone())?;
        let (norm1, norm2) = match cfg.norm {
            super::NormType::LayerNorm => {
                let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("input_layernorm"))?;
                let norm2 =
                    candle_nn::layer_norm(d_model, 1e-5, vb.pp("post_attention_layernorm"))?;
                (Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
            }
            super::NormType::RmsNorm => {
                let norm1 = RmsNorm::new(d_model, 1e-8, vb.pp("input_rmsnorm"))?;
                let norm2 = RmsNorm::new(d_model, 1e-8, vb.pp("post_attention_rmsnorm"))?;
                (Norm::RmsNorm(norm1), Norm::RmsNorm(norm2))
            }
        };
        let layer_scale_1 = match cfg.layer_scale {
            None => None,
            Some(ls) => {
                let ls = LayerScale::new(d_model, ls, vb.pp("self_attn_layer_scale"))?;
                Some(ls)
            }
        };
        let layer_scale_2 = match cfg.layer_scale {
            None => None,
            Some(ls) => {
                let ls = LayerScale::new(d_model, ls, vb.pp("mlp_layer_scale"))?;
                Some(ls)
            }
        };
        let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
        let cross_attn = if cfg.cross_attention {
            let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
            let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
            Some((norm_cross, cross_attn))
        } else {
            None
        };
        Ok(Self {
            self_attn,
            mlp,
            norm1,
            norm2,
            layer_scale_1,
            layer_scale_2,
            cross_attn,
            norm_first: cfg.norm_first,
            span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
        })
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        ca_src: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        if !self.norm_first {
            candle::bail!("only norm_first = true is supported")
        }
        let norm1 = xs.apply(&self.norm1)?;
        let xs = (xs
            + self
                .self_attn
                .forward(&norm1, mask)?
                .apply(&self.layer_scale_1.as_ref())?)?;

        let xs = match (&self.cross_attn, ca_src) {
            (Some((norm_cross, cross_attn)), Some(ca_src)) => {
                let residual = &xs;
                let xs = xs.apply(norm_cross)?;
                (residual + cross_attn.forward(&xs, ca_src, None)?)?
            }
            _ => xs,
        };

        let xs = (&xs
            + xs.apply(&self.norm2)?
                .apply(&self.mlp)?
                .apply(&self.layer_scale_2.as_ref()))?;
        Ok(xs)
    }

    pub fn reset_kv_cache(&mut self) {
        self.self_attn.reset_kv_cache()
    }

    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
        self.self_attn.set_kv_cache(kv_cache)
    }
}

#[derive(Debug, Clone)]
pub struct StreamingTransformer {
    layers: Vec<StreamingTransformerLayer>,
    context: usize,
    positional_embedding: PositionalEmbedding,
    max_period: usize,
}

impl StreamingTransformer {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let rope = match cfg.positional_embedding {
            PositionalEmbedding::Rope => {
                let rope = RotaryEmbedding::new(
                    cfg.d_model / cfg.num_heads,
                    cfg.max_seq_len,
                    cfg.max_period as f32,
                    vb.device(),
                )?;
                Some(Arc::new(rope))
            }
            PositionalEmbedding::Sin | PositionalEmbedding::None => None,
        };
        let mut layers = Vec::with_capacity(cfg.num_layers);
        for layer_idx in 0..cfg.num_layers {
            let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        Ok(Self {
            layers,
            context: cfg.context,
            positional_embedding: cfg.positional_embedding,
            max_period: cfg.max_period,
        })
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
        self.forward_ca(xs, None)
    }

    pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
        let (_b, t, c) = xs.dims3()?;
        // We will extract at most "context" from the kv_cache.
        // Note that the mask will discard the values that are before context.
        let pos = self.layers[0]
            .self_attn
            .kv_cache
            .k_cache()
            .current_seq_len()
            .min(self.context);
        let mask = if t == 1 {
            None
        } else {
            Some(get_mask(t, pos + t, self.context, xs.device())?)
        };
        let mut xs = match self.positional_embedding {
            PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
            PositionalEmbedding::Sin => {
                let dev = xs.device();
                let theta = self.max_period as f32;
                let half_dim = c / 2;
                let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)?
                    .unsqueeze(1)?
                    .to_dtype(DType::F32)?;
                let inv_freq: Vec<_> = (0..half_dim)
                    .map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
                    .collect();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let freqs = positions.broadcast_mul(&inv_freq)?;
                let pos_emb =
                    Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?.to_dtype(xs.dtype())?;
                xs.broadcast_add(&pos_emb)?
            }
        };
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, ca_src, mask.as_ref())?;
        }
        Ok(xs)
    }

    pub fn copy_state(&mut self, from: &Self) -> Result<()> {
        if self.layers.len() != from.layers.len() {
            candle::bail!("cannot copy kv-caches as the transformers have different depths")
        }
        self.layers
            .iter_mut()
            .zip(from.layers.iter())
            .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
        Ok(())
    }
}
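For the `PositionalEmbedding::Sin` branch of `forward_ca`, the embedding of each absolute position is the concatenation of cosines and sines at geometrically spaced frequencies up to `max_period`. A plain-Rust rendering of that computation with hypothetical sizes, not from the diff:

```rust
// Sinusoidal embedding for positions pos..pos+t over c channels, matching the
// cos-then-sin concatenation used above. All sizes here are hypothetical.
fn main() {
    let (pos, t, c, max_period) = (5usize, 2usize, 8usize, 10_000f32);
    let half_dim = c / 2;
    for p in pos..pos + t {
        let mut emb = Vec::with_capacity(c);
        for i in 0..half_dim {
            let inv_freq = 1f32 / max_period.powf(i as f32 / (half_dim - 1) as f32);
            emb.push((p as f32 * inv_freq).cos());
        }
        for i in 0..half_dim {
            let inv_freq = 1f32 / max_period.powf(i as f32 / (half_dim - 1) as f32);
            emb.push((p as f32 * inv_freq).sin());
        }
        println!("pos {p}: {emb:?}");
    }
}
```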
impl StreamingModule for StreamingTransformer {
    fn reset_state(&mut self) {
        self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
    }

    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
        match xs.as_option() {
            None => Ok(StreamTensor::empty()),
            Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
        }
    }
}

#[derive(Debug, Clone)]
pub struct ProjectedTransformer {
    transformer: StreamingTransformer,
    input_proj: Option<Linear>,
    output_projs: Vec<Option<Linear>>,
    conv_layout: bool,
    span: tracing::Span,
}

impl ProjectedTransformer {
    pub fn new(
        input_dim: usize,
        output_dims: &[usize],
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let transformer = StreamingTransformer::new(cfg, vb.clone())?;
        let input_proj = if input_dim == cfg.d_model {
            None
        } else {
            let l = linear_no_bias(input_dim, cfg.d_model, vb.pp("input_proj"))?;
            Some(l)
        };
        let mut output_projs = Vec::with_capacity(output_dims.len());
        let vb_o = vb.pp("output_projs");
        for (i, &output_dim) in output_dims.iter().enumerate() {
            let output_proj = if output_dim == cfg.d_model {
                None
            } else {
                let l = linear_no_bias(cfg.d_model, output_dim, vb_o.pp(i))?;
                Some(l)
            };
            output_projs.push(output_proj)
        }
        Ok(Self {
            transformer,
            input_proj,
            output_projs,
            conv_layout: cfg.conv_layout,
            span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
        })
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
        let _enter = self.span.enter();
        let xs = if self.conv_layout {
            xs.transpose(1, 2)?
        } else {
            xs.clone()
        };
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.forward(&xs)?;
        let mut ys = Vec::with_capacity(self.output_projs.len());
        for output_proj in self.output_projs.iter() {
            let ys_ = xs.apply(&output_proj.as_ref())?;
            let ys_ = if self.conv_layout {
                ys_.transpose(1, 2)?
            } else {
                ys_
            };
            ys.push(ys_)
        }
        Ok(ys)
    }
}

impl StreamingModule for ProjectedTransformer {
    fn reset_state(&mut self) {
        self.transformer.reset_state()
    }

    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
        let xs = xs.apply(&|x: &Tensor| {
            if self.conv_layout {
                x.transpose(1, 2)
            } else {
                Ok(x.clone())
            }
        })?;
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.step(&xs)?;
        let ys = xs.apply(&self.output_projs[0].as_ref())?;
        ys.apply(&|y: &Tensor| {
            if self.conv_layout {
                y.transpose(1, 2)
            } else {
                Ok(y.clone())
            }
        })
    }
}

#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}
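The two `flash_attn` definitions above use the usual Cargo feature-gate pattern: the real kernel is only compiled when the `flash-attn` feature is enabled, and the fallback panics if it is ever reached. A minimal, self-contained illustration of the same pattern with a hypothetical feature name, not from the diff:

```rust
// Hypothetical crate with a "fast-path" feature; mirrors the flash-attn gate above.
#[cfg(feature = "fast-path")]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    // imagine an optimized kernel here
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

#[cfg(not(feature = "fast-path"))]
fn dot(_a: &[f32], _b: &[f32]) -> f32 {
    unimplemented!("compile with '--features fast-path'")
}

fn main() {
    // Without the feature this panics inside catch_unwind, matching the stub behaviour above.
    let _ = std::panic::catch_unwind(|| dot(&[1.0, 2.0], &[3.0, 4.0]));
}
```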
@ -33,6 +33,7 @@ pub mod llava;
pub mod mamba;
pub mod marian;
pub mod metavoice;
pub mod mimi;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;