Add the mimi audio-tokenizer. (#2488)

* Add the mimi audio-tokenizer.

* Formatting tweaks.

* Add a full example.

* Use the transformers names.

* More renamings.

* Get encoding and decoding to work.

* Clippy fixes.
Laurent Mazare
2024-09-20 14:31:20 -06:00
committed by GitHub
parent 382c6b51af
commit c58c5d5b01
12 changed files with 3027 additions and 0 deletions

.gitignore

@@ -43,3 +43,6 @@ candle-wasm-examples/**/config*.json
__pycache__
out.safetensors
out.wav
bria.mp3
bria.safetensors
bria.wav

candle-examples/Cargo.toml

@@ -67,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
@@ -101,6 +102,10 @@ required-features = ["candle-datasets"]
name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
name = "mimi"
required-features = ["mimi"]
[[example]]
name = "encodec"
required-features = ["encodec"]

candle-examples/examples/mimi/README.md

@@ -0,0 +1,20 @@
# candle-mimi
[Mimi](https://huggingface.co/kyutai/mimi) is a state-of-the-art audio
compression model using an encoder/decoder architecture with residual vector
quantization. The candle implementation supports streaming, meaning that it is
possible to encode or decode a stream of audio tokens on the fly to provide
low-latency interaction with an audio model.
## Running an example
Generate audio tokens from an audio file:
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
```
Then decode the audio tokens back into a sound file:
```bash
cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
```
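For programmatic use, here is a minimal sketch of the same round trip in Rust, assuming the `anyhow`, `candle`, `candle-nn` and `candle-transformers` crates plus a locally downloaded `model.safetensors`; it mirrors the `Config::v0_1`/`Model::new`/`encode`/`decode` calls made in the example's `main.rs`.
```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mimi::{Config, Model};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Load the pre-trained weights, downloaded beforehand (e.g. from kyutai/mimi).
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let mut model = Model::new(Config::v0_1(None), vb)?;
    // One second of silence at mimi's 24kHz sample rate, shaped (batch, channels, time).
    let pcm = Tensor::zeros((1, 1, 24_000), DType::F32, &device)?;
    let codes = model.encode(&pcm)?; // (batch, codebooks, frames)
    let pcm_out = model.decode(&codes)?;
    println!("codes: {:?} -> pcm: {:?}", codes.shape(), pcm_out.shape());
    Ok(())
}
```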

candle-examples/examples/mimi/audio_io.rs

@@ -0,0 +1,275 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
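// Write each mono resampled sample to the first slot of an interleaved frame and
// duplicate it across the remaining channels; if the queue runs dry, the rest of
// the buffer stays at the zeros written above.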
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

candle-examples/examples/mimi/main.rs

@@ -0,0 +1,131 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mimi::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
mod audio_io;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some mimi tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some mimi tokens stored as safetensors.
out_file: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("kyutai/mimi".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::v0_1(None);
let mut model = Model::new(config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
codes.get("codes").expect("no codes in input file").clone()
}
Action::AudioToCode | Action::AudioToAudio => {
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
}
};
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
};
println!("codes shape: {:?}", codes.shape());
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = model.decode(&codes)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
}
Ok(())
}

candle-transformers/src/models/mimi/conv.rs

@@ -0,0 +1,670 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D};
use candle_nn::{Conv1d, VarBuilder};
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Norm {
WeightNorm,
SpectralNorm,
TimeGroupNorm,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PadMode {
Constant,
Reflect,
Replicate,
}
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
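// The reparametrization is w = g * v / ||v||, with the norm of v taken over the
// input-channel and kernel dimensions (dims 1 and 2), as computed below.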
fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
bias: bool,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = if vb.contains_tensor("weight") {
vb.get((out_c, in_c, kernel_size), "weight")?
} else {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
};
let bias = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
Ok(Conv1d::new(weight, bias, config))
}
#[derive(Debug, Clone)]
pub struct NormConv1d {
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
span: tracing::Span,
}
impl NormConv1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
causal: bool,
norm: Option<Norm>,
bias: bool,
cfg: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Self> {
let conv = match norm {
None | Some(Norm::TimeGroupNorm) => {
if bias {
candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
} else {
candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
}
}
Some(Norm::WeightNorm) => {
conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
}
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
};
let norm = match norm {
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
Some(Norm::TimeGroupNorm) => {
if causal {
candle::bail!("GroupNorm doesn't support causal evaluation.")
}
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(norm)
}
};
Ok(Self {
conv,
norm,
span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
})
}
}
impl Module for NormConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = xs.apply(&self.conv)?;
match self.norm.as_ref() {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
#[derive(Debug, Clone)]
pub struct NormConvTranspose1d {
ws: Tensor,
bs: Option<Tensor>,
k_size: usize,
stride: usize,
groups: usize,
norm: Option<candle_nn::GroupNorm>,
span: tracing::Span,
}
impl NormConvTranspose1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
causal: bool,
norm: Option<Norm>,
bias: bool,
stride: usize,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("conv");
let bs = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
let ws = match norm {
None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
Some(Norm::WeightNorm) => {
if vb.contains_tensor("weight") {
vb.get((in_c, out_c, k_size), "weight")?
} else {
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
}
}
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
};
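// Depthwise case (groups == in_c == out_c): the conv_transpose1d call in forward
// runs with groups = 1 after this, so expand the kernel into an equivalent dense
// one by masking a channel-repeated kernel with the identity matrix.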
let (ws, groups) = if groups == out_c && in_c == out_c {
let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
let ws = ws
.repeat((1, out_c, 1))?
.mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
(ws, 1)
} else {
(ws, groups)
};
let norm = match norm {
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
Some(Norm::TimeGroupNorm) => {
if causal {
candle::bail!("GroupNorm doesn't support causal evaluation.")
}
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(norm)
}
};
Ok(Self {
ws,
bs,
k_size,
stride,
groups,
norm,
span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
})
}
}
impl Module for NormConvTranspose1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
// conv-transpose1d used to be broken on metal after enough iterations, causing
// the following error:
// _status < MTLCommandBufferStatusCommitted >
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
// This is now fixed in candle.
let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
let xs = match &self.bs {
None => xs,
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1))?;
xs.broadcast_add(&bias)?
}
};
match self.norm.as_ref() {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
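// Amount of extra right-padding required so that the last convolution window is
// complete and no trailing samples are dropped.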
fn get_extra_padding_for_conv1d(
xs: &Tensor,
k_size: usize,
stride: usize,
padding_total: usize,
) -> Result<usize> {
let len = xs.dim(D::Minus1)?;
let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
let ideal_len =
((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
Ok(ideal_len.saturating_sub(len))
}
fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
match mode {
PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
}
}
fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
let len = xs.dim(D::Minus1)?;
if len < unpad_l + unpad_r {
candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
}
xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
}
#[derive(Debug, Clone)]
pub struct StreamableConv1d {
conv: NormConv1d,
causal: bool,
pad_mode: PadMode,
state_prev_xs: StreamTensor,
left_pad_applied: bool,
kernel_size: usize,
span: tracing::Span,
}
impl StreamableConv1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
stride: usize,
dilation: usize,
groups: usize,
bias: bool,
causal: bool,
norm: Option<Norm>,
pad_mode: PadMode,
vb: VarBuilder,
) -> Result<Self> {
let cfg = candle_nn::Conv1dConfig {
padding: 0,
stride,
dilation,
groups,
};
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
if k_size < stride {
candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
}
Ok(Self {
conv,
causal,
pad_mode,
state_prev_xs: StreamTensor::empty(),
left_pad_applied: false,
kernel_size: k_size,
span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
})
}
}
impl Module for StreamableConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b, _t, _c) = xs.dims3()?;
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
let conv_cfg = self.conv.conv.config();
// Effective kernel size with dilations.
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
let padding_total = k_size - conv_cfg.stride;
let extra_padding =
get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
let xs = if self.causal {
pad1d(xs, padding_total, extra_padding, self.pad_mode)?
} else {
let padding_right = padding_total / 2;
let padding_left = padding_total - padding_right;
pad1d(
xs,
padding_left,
padding_right + extra_padding,
self.pad_mode,
)?
};
xs.apply(&self.conv)
}
}
impl StreamingModule for StreamableConv1d {
fn reset_state(&mut self) {
self.state_prev_xs.reset();
self.left_pad_applied = false;
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let xs = match xs.as_option() {
None => return Ok(().into()),
Some(xs) => xs.clone(),
};
let xs = if self.left_pad_applied {
xs
} else {
self.left_pad_applied = true;
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
let conv_cfg = self.conv.conv.config();
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
let padding_total = k_size - conv_cfg.stride;
pad1d(&xs, padding_total, 0, self.pad_mode)?
};
let cfg = self.conv.conv.config();
let stride = cfg.stride;
let dilation = cfg.dilation;
let kernel = (self.kernel_size - 1) * dilation + 1;
let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
let seq_len = xs.seq_len(D::Minus1)?;
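// A frame can be emitted once a full (dilated) kernel window has been buffered;
// samples that cannot yet produce a frame are kept in state_prev_xs for the next step.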
let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
if num_frames > 0 {
let offset = num_frames * stride;
self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
let in_l = (num_frames - 1) * stride + kernel;
let xs = xs.narrow(D::Minus1, 0, in_l)?;
// We apply the underlying conv directly rather than through forward so as
// not to apply any padding here.
xs.apply(&self.conv.conv)
} else {
self.state_prev_xs = xs;
Ok(StreamTensor::empty())
}
}
}
#[derive(Debug, Clone)]
pub struct StreamableConvTranspose1d {
convtr: NormConvTranspose1d,
causal: bool,
state_prev_ys: StreamTensor,
kernel_size: usize,
span: tracing::Span,
}
impl StreamableConvTranspose1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
stride: usize,
groups: usize,
bias: bool,
causal: bool,
norm: Option<Norm>,
vb: VarBuilder,
) -> Result<Self> {
let convtr =
NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?;
Ok(Self {
convtr,
causal,
kernel_size: k_size,
state_prev_ys: StreamTensor::empty(),
span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
})
}
}
impl Module for StreamableConvTranspose1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let k_size = self.convtr.k_size;
let stride = self.convtr.stride;
let padding_total = k_size.saturating_sub(stride);
let xs = xs.apply(&self.convtr)?;
if self.causal {
// This corresponds to trim_right_ratio = 1.
unpad1d(&xs, 0, padding_total)
} else {
let padding_right = padding_total / 2;
let padding_left = padding_total - padding_right;
unpad1d(&xs, padding_left, padding_right)
}
}
}
impl StreamingModule for StreamableConvTranspose1d {
fn reset_state(&mut self) {
self.state_prev_ys.reset()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let xs = match xs.as_option() {
Some(xs) => xs,
None => return Ok(StreamTensor::empty()),
};
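// Streaming transposed convolution amounts to overlap-add: the first outputs of
// this step overlap the tail of the previous step and must be summed with it.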
let stride = self.convtr.stride;
// We apply the underlying convtr directly rather than through forward so as
// not to apply any padding here.
let ys = self.convtr.forward(xs)?;
let ot = ys.dim(D::Minus1)?;
let ys = match self.state_prev_ys.as_option() {
None => ys,
Some(prev_ys) => {
let pt = prev_ys.dim(D::Minus1)?;
// Remove the bias as it will be applied multiple times.
let prev_ys = match &self.convtr.bs {
None => prev_ys.clone(),
Some(bias) => {
let bias = bias.reshape((1, (), 1))?;
prev_ys.broadcast_sub(&bias)?
}
};
let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
Tensor::cat(&[ys1, ys2], D::Minus1)?
}
};
let invalid_steps = self.kernel_size - stride;
let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
self.state_prev_ys = prev_ys;
Ok(ys)
}
}
#[derive(Debug, Clone)]
pub struct ConvDownsample1d {
conv: StreamableConv1d,
}
impl ConvDownsample1d {
pub fn new(
stride: usize,
dim: usize,
causal: bool,
learnt: bool,
vb: VarBuilder,
) -> Result<Self> {
if !learnt {
candle::bail!("only learnt=true is supported")
}
let conv = StreamableConv1d::new(
/* in_c */ dim,
/* out_c */ dim,
/* k_size */ 2 * stride,
/* stride */ stride,
/* dilation */ 1,
/* groups */ 1, // channel_wise = false
/* bias */ false,
/* causal */ causal,
/* norm */ None,
/* pad_mode */ PadMode::Replicate,
vb,
)?;
Ok(Self { conv })
}
}
impl Module for ConvDownsample1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.conv)
}
}
impl StreamingModule for ConvDownsample1d {
fn reset_state(&mut self) {
self.conv.reset_state()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
self.conv.step(xs)
}
}
#[derive(Debug, Clone)]
pub struct ConvTrUpsample1d {
convtr: StreamableConvTranspose1d,
}
impl ConvTrUpsample1d {
pub fn new(
stride: usize,
dim: usize,
causal: bool,
learnt: bool,
vb: VarBuilder,
) -> Result<Self> {
if !learnt {
candle::bail!("only learnt=true is supported")
}
let convtr = StreamableConvTranspose1d::new(
dim,
dim,
/* k_size */ 2 * stride,
/* stride */ stride,
/* groups */ dim,
/* bias */ false,
/* causal */ causal,
/* norm */ None,
vb,
)?;
Ok(Self { convtr })
}
}
impl Module for ConvTrUpsample1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.convtr)
}
}
impl StreamingModule for ConvTrUpsample1d {
fn reset_state(&mut self) {
self.convtr.reset_state()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
self.convtr.step(xs)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle::IndexOp;
fn run_conv1d(
k_size: usize,
stride: usize,
dilation: usize,
step_size: usize,
len: usize,
bias: bool,
) -> Result<()> {
// TODO: We should ensure that the seed is constant when running these tests.
let dev = &candle::Device::Cpu;
let vm = candle_nn::VarMap::new();
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
let conv1d = StreamableConv1d::new(
/* in_c */ 2,
/* out_c */ 3,
/* k_size */ k_size,
/* stride */ stride,
/* dilation */ dilation,
/* groups */ 1,
/* bias */ bias,
/* causal */ true,
/* norm */ None,
/* pad_mode */ PadMode::Constant,
vb,
)?;
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
let ys = conv1d.forward(&xs)?;
let mut conv1d = conv1d;
let mut ys_steps = vec![];
for idx in 0..len {
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
let ys = conv1d.step(&xs.into())?;
if let Some(ys) = ys.as_option() {
ys_steps.push(ys.clone())
}
}
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
let diff = (&ys - &ys_steps)?
.abs()?
.flatten_all()?
.max(0)?
.to_vec0::<f32>()?;
if diff > 1e-5 {
println!("{xs}");
println!("{ys}");
println!("{ys_steps}");
candle::bail!("larger diff than expected {diff}")
}
Ok(())
}
fn run_conv_tr1d(
k_size: usize,
stride: usize,
step_size: usize,
len: usize,
bias: bool,
) -> Result<()> {
// TODO: We should ensure that the seed is constant when running these tests.
let dev = &candle::Device::Cpu;
let vm = candle_nn::VarMap::new();
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
let conv1d = StreamableConvTranspose1d::new(
/* in_c */ 2, /* out_c */ 3, /* k_size */ k_size,
/* stride */ stride, /* groups */ 1, /* bias */ bias,
/* causal */ true, /* norm */ None, vb,
)?;
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
let ys = conv1d.forward(&xs)?;
let mut conv1d = conv1d;
let mut ys_steps = vec![];
for idx in 0..len {
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
let ys = conv1d.step(&xs.into())?;
if let Some(ys) = ys.as_option() {
ys_steps.push(ys.clone())
}
}
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
let diff = (&ys - &ys_steps)?
.abs()?
.flatten_all()?
.max(0)?
.to_vec0::<f32>()?;
if diff > 1e-5 {
println!("{xs}");
println!("{ys}");
println!("{ys_steps}");
candle::bail!("larger diff than expected {diff}")
}
Ok(())
}
#[test]
fn conv1d() -> Result<()> {
for step_size in [1, 2, 3] {
for bias in [false, true] {
run_conv1d(1, 1, 1, step_size, 5, bias)?;
run_conv1d(2, 1, 1, step_size, 5, bias)?;
run_conv1d(2, 2, 1, step_size, 6, bias)?;
run_conv1d(3, 2, 1, step_size, 8, bias)?;
run_conv1d(3, 2, 2, step_size, 8, bias)?;
}
}
Ok(())
}
#[test]
fn conv_tr1d() -> Result<()> {
for step_size in [1, 2, 3] {
for bias in [false, true] {
run_conv_tr1d(1, 1, step_size, 5, bias)?;
run_conv_tr1d(2, 1, step_size, 5, bias)?;
run_conv_tr1d(3, 1, step_size, 5, bias)?;
run_conv_tr1d(3, 2, step_size, 5, bias)?;
}
}
Ok(())
}
}

candle-transformers/src/models/mimi/encodec.rs

@@ -0,0 +1,229 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use super::{conv, quantization, seanet, transformer};
use candle::{DType, Device, Module, Result, StreamTensor, StreamingModule, Tensor};
use candle_nn::VarBuilder;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ResampleMethod {
Conv,
Interpolate,
}
#[derive(Debug, Clone)]
pub struct Config {
pub channels: usize,
pub sample_rate: f64,
pub frame_rate: f64,
pub renormalize: bool,
pub resample_method: ResampleMethod,
pub seanet: seanet::Config,
pub transformer: transformer::Config,
pub quantizer_n_q: usize,
pub quantizer_bins: usize,
pub quantizer_dim: usize,
}
impl Config {
// /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml
pub fn v0_1(num_codebooks: Option<usize>) -> Self {
let seanet_cfg = seanet::Config {
dimension: 512,
channels: 1,
causal: true,
n_filters: 64,
n_residual_layers: 1,
activation: candle_nn::Activation::Elu(1.),
compress: 2,
dilation_base: 2,
disable_norm_outer_blocks: 0,
final_activation: None,
kernel_size: 7,
residual_kernel_size: 3,
last_kernel_size: 3,
lstm: 0,
norm: conv::Norm::WeightNorm,
pad_mode: conv::PadMode::Constant,
ratios: vec![8, 6, 5, 4],
true_skip: true,
};
let transformer_cfg = transformer::Config {
d_model: seanet_cfg.dimension,
num_heads: 8,
num_layers: 8,
causal: true,
norm_first: true,
bias_ff: false,
bias_attn: false,
layer_scale: Some(0.01),
context: 250,
conv_kernel_size: 5,
use_conv_bias: true,
use_conv_block: false,
cross_attention: false,
max_period: 10000,
gating: None,
norm: super::NormType::LayerNorm,
positional_embedding: transformer::PositionalEmbedding::Rope,
dim_feedforward: 2048,
kv_repeat: 1,
conv_layout: true, // see builders.py
max_seq_len: 8192, // the transformer works at 25hz so this is ~5 mins.
};
Config {
channels: 1,
sample_rate: 24_000.,
frame_rate: 12.5,
renormalize: true,
resample_method: ResampleMethod::Conv,
seanet: seanet_cfg,
transformer: transformer_cfg,
quantizer_n_q: num_codebooks.unwrap_or(16),
quantizer_bins: 2048,
quantizer_dim: 256,
}
}
}
#[derive(Debug, Clone)]
pub struct Encodec {
encoder: seanet::SeaNetEncoder,
decoder: seanet::SeaNetDecoder,
encoder_transformer: transformer::ProjectedTransformer,
decoder_transformer: transformer::ProjectedTransformer,
downsample: conv::ConvDownsample1d,
upsample: conv::ConvTrUpsample1d,
quantizer: quantization::SplitResidualVectorQuantizer,
config: Config,
}
impl Encodec {
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
let dim = cfg.seanet.dimension;
let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
let encoder_transformer = transformer::ProjectedTransformer::new(
dim,
&[dim],
&cfg.transformer,
vb.pp("encoder_transformer"),
)?;
let decoder_transformer = transformer::ProjectedTransformer::new(
dim,
&[dim],
&cfg.transformer,
vb.pp("decoder_transformer"),
)?;
let quantizer = quantization::SplitResidualVectorQuantizer::new(
/* dim */ cfg.quantizer_dim,
/* input_dim */ Some(dim),
/* output_dim */ Some(dim),
/* n_q */ cfg.quantizer_n_q,
/* bins */ cfg.quantizer_bins,
vb.pp("quantizer"),
)?;
let encoder_frame_rate =
cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
// `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate.
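// For the v0.1 config this is 24_000 / (8 * 6 * 5 * 4) = 25 Hz at the encoder
// output, so the stride down to the 12.5 Hz frame rate is 2.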
let downsample = conv::ConvDownsample1d::new(
/* stride */ downsample_stride,
/* dim */ dim,
/* causal */ true,
/* learnt */ true,
vb.pp("downsample"),
)?;
let upsample = conv::ConvTrUpsample1d::new(
/* stride */ downsample_stride,
/* dim */ dim,
/* causal */ true,
/* learnt */ true,
vb.pp("upsample"),
)?;
Ok(Self {
encoder,
decoder,
encoder_transformer,
decoder_transformer,
quantizer,
downsample,
upsample,
config: cfg,
})
}
pub fn config(&self) -> &Config {
&self.config
}
pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
let xs = self.encoder.forward(xs)?;
self.encoder_transformer.reset_state();
let xs = self.encoder_transformer.forward(&xs)?;
let xs = &xs[0];
xs.apply(&self.downsample)
}
pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
let xs = self.encoder.forward(xs)?;
self.encoder_transformer.reset_state();
let xs = self.encoder_transformer.forward(&xs)?;
let xs = &xs[0];
let xs = xs.apply(&self.downsample)?;
let codes = self.quantizer.encode(&xs)?;
Ok(codes)
}
pub fn encode_step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let xs = self.encoder.step(xs)?;
let xs = self.encoder_transformer.step(&xs)?;
let xs = self.downsample.step(&xs)?;
match xs.as_option() {
None => Ok(().into()),
Some(xs) => {
let codes = self.quantizer.encode(xs)?;
Ok(codes.into())
}
}
}
pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
let emb = self.quantizer.decode(codes)?;
let emb = emb.apply(&self.upsample)?;
self.decoder_transformer.reset_state();
let outs = self.decoder_transformer.forward(&emb)?;
let out = &outs[0];
self.decoder.forward(out)
}
pub fn decode_step(&mut self, codes: &StreamTensor) -> Result<StreamTensor> {
let emb = match codes.as_option() {
Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
None => StreamTensor::empty(),
};
let emb = self.upsample.step(&emb)?;
let out = self.decoder_transformer.step(&emb)?;
self.decoder.step(&out)
}
pub fn reset_state(&mut self) {
self.encoder.reset_state();
self.encoder_transformer.reset_state();
self.decoder.reset_state();
self.decoder_transformer.reset_state();
self.upsample.reset_state();
}
}
pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Encodec> {
let vb =
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
let cfg = Config::v0_1(num_codebooks);
let encodec = Encodec::new(cfg, vb)?;
Ok(encodec)
}

candle-transformers/src/models/mimi/mod.rs

@@ -0,0 +1,22 @@
// Adapted from the reference implementation at:
// https://github.com/kyutai-labs/moshi
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
pub use candle;
pub use candle_nn;
pub mod conv;
pub mod encodec;
pub mod quantization;
pub mod seanet;
pub mod transformer;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum NormType {
RmsNorm,
LayerNorm,
}
pub use encodec::{load, Config, Encodec as Model};

candle-transformers/src/models/mimi/quantization.rs

@@ -0,0 +1,404 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{linear, Linear, VarBuilder};
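// Custom CPU op: for each row of the lhs matrix, find the index of the rhs row
// (codebook entry) with the smallest squared L2 distance, parallelized over rows
// with rayon.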
struct CodebookEncode;
impl candle::CustomOp2 for CodebookEncode {
fn name(&self) -> &'static str {
"cb"
}
fn cpu_fwd(
&self,
lhs_storage: &candle::CpuStorage,
lhs_layout: &Layout,
rhs_storage: &candle::CpuStorage,
rhs_layout: &Layout,
) -> Result<(candle::CpuStorage, Shape)> {
use rayon::prelude::*;
let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
if lhs_dim2 != rhs_dim2 {
candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
}
if lhs_dim2 == 0 {
candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
}
let lhs = match lhs_layout.contiguous_offsets() {
None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
Some((o1, o2)) => {
let slice = lhs_storage.as_slice::<f32>()?;
&slice[o1..o2]
}
};
let rhs = match rhs_layout.contiguous_offsets() {
None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
Some((o1, o2)) => {
let slice = rhs_storage.as_slice::<f32>()?;
&slice[o1..o2]
}
};
let dst = (0..lhs_dim1)
.into_par_iter()
.map(|idx1| {
let mut where_min = 0;
let mut min_dist = f32::INFINITY;
let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
for idx2 in 0..rhs_dim1 {
let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
let mut dist = 0f32;
for (a, b) in lhs.iter().zip(rhs.iter()) {
dist += (a - b) * (a - b)
}
if dist < min_dist {
min_dist = dist;
where_min = idx2;
}
}
where_min as u32
})
.collect();
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (lhs_dim1,).into()))
}
}
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct EuclideanCodebook {
initialized: Tensor,
cluster_usage: Tensor,
embedding_sum: Tensor,
embedding: Tensor,
c2: Tensor,
epsilon: f64,
dim: usize,
span_encode: tracing::Span,
span_decode: tracing::Span,
}
impl EuclideanCodebook {
pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
let epsilon = 1e-5;
let initialized = vb.get(1, "initialized")?;
let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
let embedding_sum = vb.get((codebook_size, dim), "embed_sum")?;
let embedding = {
let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
embedding_sum.broadcast_div(&cluster_usage)?
};
let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
Ok(Self {
initialized,
cluster_usage,
embedding_sum,
embedding,
c2,
epsilon,
dim,
span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-decode"),
})
}
pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span_encode.enter();
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
// TODO: avoid repeating this.
let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
// Manual cdist implementation.
let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
let dists = diff.sqr()?.sum(D::Minus1)?;
let codes = dists.argmin(D::Minus1)?;
codes.reshape(target_shape)
}
pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span_encode.enter();
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
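// argmin_e ||x - e||^2 = argmin_e (||e||^2 / 2 - x.e) since ||x||^2 is constant
// in e; c2 caches ||e||^2 / 2, so a single matmul finds the nearest entry.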
let dot_prod = xs.matmul(&self.embedding.t()?)?;
let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
codes.reshape(target_shape)
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span_encode.enter();
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
codes.reshape(target_shape)
}
pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
let _enter = self.span_decode.enter();
// let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?;
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.dim);
let indexes = indexes.flatten_all()?;
let values = self.embedding.index_select(&indexes, 0)?;
let values = values.reshape(final_dims)?;
Ok(values)
}
}
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct VectorQuantization {
project_in: Option<Linear>,
project_out: Option<Linear>,
codebook: EuclideanCodebook,
}
impl VectorQuantization {
pub fn new(
dim: usize,
codebook_size: usize,
codebook_dim: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let codebook_dim = codebook_dim.unwrap_or(dim);
let (project_in, project_out) = if codebook_dim == dim {
(None, None)
} else {
let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
(Some(p_in), Some(p_out))
};
let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("codebook"))?;
Ok(Self {
project_in,
project_out,
codebook,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.t()?.apply(&self.project_in.as_ref())?;
self.codebook.encode_slow(&xs)
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let quantized = self.codebook.decode(codes)?;
let quantized = match &self.project_out {
None => quantized,
Some(p) => quantized.apply(p)?,
};
quantized.t()
}
}
#[derive(Debug, Clone)]
pub struct ResidualVectorQuantization {
layers: Vec<VectorQuantization>,
}
impl ResidualVectorQuantization {
pub fn new(
n_q: usize,
dim: usize,
codebook_size: usize,
codebook_dim: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("layers");
let mut layers = Vec::with_capacity(n_q);
for i in 0..n_q {
let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
layers.push(layer)
}
Ok(Self { layers })
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let mut codes = Vec::with_capacity(self.layers.len());
let mut residual = xs.clone();
for layer in self.layers.iter() {
let indices = layer.encode(&residual)?;
let quantized = layer.decode(&indices)?;
residual = (residual - quantized)?;
codes.push(indices)
}
Tensor::stack(&codes, 0)
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
if self.layers.is_empty() {
candle::bail!("empty layers in ResidualVectorQuantization")
}
if self.layers.len() != xs.dim(0)? {
candle::bail!(
"mismatch between the number of layers {} and the code shape {:?}",
self.layers.len(),
xs.shape()
)
}
let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
for (i, layer) in self.layers.iter().enumerate().skip(1) {
let xs = xs.i(i)?;
quantized = (quantized + layer.decode(&xs))?
}
Ok(quantized)
}
}
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct ResidualVectorQuantizer {
vq: ResidualVectorQuantization,
input_proj: Option<candle_nn::Conv1d>,
output_proj: Option<candle_nn::Conv1d>,
}
impl ResidualVectorQuantizer {
pub fn new(
dim: usize,
input_dim: Option<usize>,
output_dim: Option<usize>,
n_q: usize,
bins: usize,
force_projection: bool,
vb: VarBuilder,
) -> Result<Self> {
let input_dim = input_dim.unwrap_or(dim);
let output_dim = output_dim.unwrap_or(dim);
let input_proj = if input_dim == dim && !force_projection {
None
} else {
let c = candle_nn::conv1d_no_bias(
input_dim,
dim,
1,
Default::default(),
vb.pp("input_proj"),
)?;
Some(c)
};
let output_proj = if output_dim == dim && !force_projection {
None
} else {
let c = candle_nn::conv1d_no_bias(
dim,
output_dim,
1,
Default::default(),
vb.pp("output_proj"),
)?;
Some(c)
};
let vq = ResidualVectorQuantization::new(
n_q, dim, /* codebook_size */ bins, /* codebook_dim */ None, vb,
)?;
Ok(Self {
vq,
input_proj,
output_proj,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
codes.transpose(0, 1)
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
// codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
let codes = codes.transpose(0, 1)?;
let quantized = self.vq.decode(&codes)?;
match &self.output_proj {
None => Ok(quantized),
Some(p) => quantized.apply(p),
}
}
}
// We do not use any codebook_offset at the moment. When reconstructing the codes, we could just
// concatenate the indexes.
#[derive(Debug, Clone)]
pub struct SplitResidualVectorQuantizer {
rvq_first: ResidualVectorQuantizer,
rvq_rest: ResidualVectorQuantizer,
n_q: usize,
span_encode: tracing::Span,
span_decode: tracing::Span,
}
impl SplitResidualVectorQuantizer {
pub fn new(
dim: usize,
input_dim: Option<usize>,
output_dim: Option<usize>,
n_q: usize,
bins: usize,
vb: VarBuilder,
) -> Result<Self> {
let rvq_first = ResidualVectorQuantizer::new(
dim,
input_dim,
output_dim,
1,
bins,
true,
vb.pp("semantic_residual_vector_quantizer"),
)?;
let rvq_rest = ResidualVectorQuantizer::new(
dim,
input_dim,
output_dim,
n_q - 1,
bins,
true,
vb.pp("acoustic_residual_vector_quantizer"),
)?;
let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
Ok(Self {
rvq_first,
rvq_rest,
n_q,
span_encode,
span_decode,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span_encode.enter();
let codes = self.rvq_first.encode(xs)?;
if self.n_q > 1 {
// We encode xs again here rather than the residual. The decomposition is not
// hierarchical but rather having semantic tokens for rvq_first and the acoustic tokens
// for rvq_rest.
let rest_codes = self.rvq_rest.encode(xs)?;
Tensor::cat(&[codes, rest_codes], 1)
} else {
Ok(codes)
}
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
// codes is [B, K, T], with T frames, K nb of codebooks.
let _enter = self.span_decode.enter();
let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
let quantized = if self.n_q > 1 {
(quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
} else {
quantized
};
Ok(quantized)
}
}

candle-transformers/src/models/mimi/seanet.rs

@@ -0,0 +1,465 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor};
use candle_nn::VarBuilder;
use super::conv::{StreamableConv1d, StreamableConvTranspose1d};
#[derive(Debug, Clone)]
pub struct Config {
pub dimension: usize,
pub channels: usize,
pub causal: bool,
pub n_filters: usize,
pub n_residual_layers: usize,
pub ratios: Vec<usize>,
pub activation: candle_nn::Activation,
pub norm: super::conv::Norm,
pub kernel_size: usize,
pub residual_kernel_size: usize,
pub last_kernel_size: usize,
pub dilation_base: usize,
pub pad_mode: super::conv::PadMode,
pub true_skip: bool,
pub compress: usize,
pub lstm: usize,
pub disable_norm_outer_blocks: usize,
pub final_activation: Option<candle_nn::Activation>,
}
#[derive(Debug, Clone)]
pub struct SeaNetResnetBlock {
block: Vec<StreamableConv1d>,
shortcut: Option<StreamableConv1d>,
activation: candle_nn::Activation,
skip_op: candle::StreamingBinOp,
span: tracing::Span,
}
impl SeaNetResnetBlock {
#[allow(clippy::too_many_arguments)]
pub fn new(
dim: usize,
k_sizes_and_dilations: &[(usize, usize)],
activation: candle_nn::Activation,
norm: Option<super::conv::Norm>,
causal: bool,
pad_mode: super::conv::PadMode,
compress: usize,
true_skip: bool,
vb: VarBuilder,
) -> Result<Self> {
let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
let hidden = dim / compress;
let vb_b = vb.pp("block");
for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
let in_c = if i == 0 { dim } else { hidden };
let out_c = if i == k_sizes_and_dilations.len() - 1 {
dim
} else {
hidden
};
let c = StreamableConv1d::new(
in_c,
out_c,
/* k_size */ *k_size,
/* stride */ 1,
/* dilation */ *dilation,
/* groups */ 1,
/* bias */ true,
/* causal */ causal,
/* norm */ norm,
/* pad_mode */ pad_mode,
vb_b.pp(2 * i + 1),
)?;
block.push(c)
}
let shortcut = if true_skip {
None
} else {
let c = StreamableConv1d::new(
dim,
dim,
/* k_size */ 1,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ causal,
/* norm */ norm,
/* pad_mode */ pad_mode,
vb.pp("shortcut"),
)?;
Some(c)
};
Ok(Self {
block,
shortcut,
activation,
skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
})
}
}
impl Module for SeaNetResnetBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut ys = xs.clone();
for block in self.block.iter() {
ys = ys.apply(&self.activation)?.apply(block)?;
}
match self.shortcut.as_ref() {
None => ys + xs,
Some(shortcut) => ys + xs.apply(shortcut),
}
}
}
impl StreamingModule for SeaNetResnetBlock {
fn reset_state(&mut self) {
for block in self.block.iter_mut() {
block.reset_state()
}
if let Some(shortcut) = self.shortcut.as_mut() {
shortcut.reset_state()
}
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let mut ys = xs.clone();
for block in self.block.iter_mut() {
ys = block.step(&ys.apply(&self.activation)?)?;
}
match self.shortcut.as_ref() {
None => self.skip_op.step(&ys, xs),
Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?),
}
}
}
#[derive(Debug, Clone)]
struct EncoderLayer {
residuals: Vec<SeaNetResnetBlock>,
downsample: StreamableConv1d,
}
#[derive(Debug, Clone)]
pub struct SeaNetEncoder {
init_conv1d: StreamableConv1d,
activation: candle_nn::Activation,
layers: Vec<EncoderLayer>,
final_conv1d: StreamableConv1d,
span: tracing::Span,
}
impl SeaNetEncoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
if cfg.lstm > 0 {
candle::bail!("seanet lstm is not supported")
}
let n_blocks = 2 + cfg.ratios.len();
let mut mult = 1usize;
let init_norm = if cfg.disable_norm_outer_blocks >= 1 {
None
} else {
Some(cfg.norm)
};
let mut layer_idx = 0;
let vb = vb.pp("layers");
let init_conv1d = StreamableConv1d::new(
cfg.channels,
mult * cfg.n_filters,
cfg.kernel_size,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ cfg.causal,
/* norm */ init_norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx),
)?;
layer_idx += 1;
let mut layers = Vec::with_capacity(cfg.ratios.len());
for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
let norm = if cfg.disable_norm_outer_blocks >= i + 2 {
None
} else {
Some(cfg.norm)
};
let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
for j in 0..cfg.n_residual_layers {
let resnet_block = SeaNetResnetBlock::new(
mult * cfg.n_filters,
&[
(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
(1, 1),
],
cfg.activation,
norm,
cfg.causal,
cfg.pad_mode,
cfg.compress,
cfg.true_skip,
vb.pp(layer_idx),
)?;
residuals.push(resnet_block);
layer_idx += 1;
}
let downsample = StreamableConv1d::new(
mult * cfg.n_filters,
mult * cfg.n_filters * 2,
/* k_size */ ratio * 2,
/* stride */ ratio,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ true,
/* norm */ norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx + 1),
)?;
layer_idx += 2;
let layer = EncoderLayer {
downsample,
residuals,
};
layers.push(layer);
mult *= 2
}
let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks {
None
} else {
Some(cfg.norm)
};
let final_conv1d = StreamableConv1d::new(
mult * cfg.n_filters,
cfg.dimension,
cfg.last_kernel_size,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ cfg.causal,
/* norm */ final_norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx + 1),
)?;
Ok(Self {
init_conv1d,
activation: cfg.activation,
layers,
final_conv1d,
span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
})
}
}
impl Module for SeaNetEncoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.apply(&self.init_conv1d)?;
for layer in self.layers.iter() {
for residual in layer.residuals.iter() {
xs = xs.apply(residual)?
}
xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
}
xs.apply(&self.activation)?.apply(&self.final_conv1d)
}
}
impl StreamingModule for SeaNetEncoder {
fn reset_state(&mut self) {
self.init_conv1d.reset_state();
self.layers.iter_mut().for_each(|v| {
v.residuals.iter_mut().for_each(|v| v.reset_state());
v.downsample.reset_state()
});
self.final_conv1d.reset_state();
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let mut xs = self.init_conv1d.step(xs)?;
for layer in self.layers.iter_mut() {
for residual in layer.residuals.iter_mut() {
xs = residual.step(&xs)?;
}
xs = layer.downsample.step(&xs.apply(&self.activation)?)?;
}
self.final_conv1d.step(&xs.apply(&self.activation)?)
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
upsample: StreamableConvTranspose1d,
residuals: Vec<SeaNetResnetBlock>,
}
#[derive(Debug, Clone)]
pub struct SeaNetDecoder {
init_conv1d: StreamableConv1d,
activation: candle_nn::Activation,
layers: Vec<DecoderLayer>,
final_conv1d: StreamableConv1d,
final_activation: Option<candle_nn::Activation>,
span: tracing::Span,
}
impl SeaNetDecoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
if cfg.lstm > 0 {
candle::bail!("seanet lstm is not supported")
}
let n_blocks = 2 + cfg.ratios.len();
let mut mult = 1 << cfg.ratios.len();
let init_norm = if cfg.disable_norm_outer_blocks == n_blocks {
None
} else {
Some(cfg.norm)
};
let mut layer_idx = 0;
let vb = vb.pp("layers");
let init_conv1d = StreamableConv1d::new(
cfg.dimension,
mult * cfg.n_filters,
cfg.kernel_size,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ cfg.causal,
/* norm */ init_norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx),
)?;
layer_idx += 1;
let mut layers = Vec::with_capacity(cfg.ratios.len());
for (i, &ratio) in cfg.ratios.iter().enumerate() {
let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
None
} else {
Some(cfg.norm)
};
let upsample = StreamableConvTranspose1d::new(
mult * cfg.n_filters,
mult * cfg.n_filters / 2,
/* k_size */ ratio * 2,
/* stride */ ratio,
/* groups */ 1,
/* bias */ true,
/* causal */ true,
/* norm */ norm,
vb.pp(layer_idx + 1),
)?;
layer_idx += 2;
let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
for j in 0..cfg.n_residual_layers {
let resnet_block = SeaNetResnetBlock::new(
mult * cfg.n_filters / 2,
&[
(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
(1, 1),
],
cfg.activation,
norm,
cfg.causal,
cfg.pad_mode,
cfg.compress,
cfg.true_skip,
vb.pp(layer_idx),
)?;
residuals.push(resnet_block);
layer_idx += 1;
}
let layer = DecoderLayer {
upsample,
residuals,
};
layers.push(layer);
mult /= 2
}
let final_norm = if cfg.disable_norm_outer_blocks >= 1 {
None
} else {
Some(cfg.norm)
};
let final_conv1d = StreamableConv1d::new(
cfg.n_filters,
cfg.channels,
cfg.last_kernel_size,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ cfg.causal,
/* norm */ final_norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx + 1),
)?;
Ok(Self {
init_conv1d,
activation: cfg.activation,
layers,
final_conv1d,
final_activation: cfg.final_activation,
span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
})
}
}
impl Module for SeaNetDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.apply(&self.init_conv1d)?;
for layer in self.layers.iter() {
xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
for residual in layer.residuals.iter() {
xs = xs.apply(residual)?
}
}
let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
let xs = match self.final_activation.as_ref() {
None => xs,
Some(act) => xs.apply(act)?,
};
Ok(xs)
}
}
impl StreamingModule for SeaNetDecoder {
fn reset_state(&mut self) {
self.init_conv1d.reset_state();
self.layers.iter_mut().for_each(|v| {
v.residuals.iter_mut().for_each(|v| v.reset_state());
v.upsample.reset_state()
});
self.final_conv1d.reset_state();
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let mut xs = self.init_conv1d.step(xs)?;
for layer in self.layers.iter_mut() {
xs = layer.upsample.step(&xs.apply(&self.activation)?)?;
for residual in layer.residuals.iter_mut() {
xs = residual.step(&xs)?;
}
}
let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?;
let xs = match self.final_activation.as_ref() {
None => xs,
Some(act) => xs.apply(act)?,
};
Ok(xs)
}
}

candle-transformers/src/models/mimi/transformer.rs

@@ -0,0 +1,802 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{DType, Device, IndexOp, Module, Result, StreamTensor, StreamingModule, Tensor, D};
use candle_nn::{linear_no_bias, Linear, VarBuilder};
use std::sync::Arc;
fn linear(in_d: usize, out_d: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
if bias {
candle_nn::linear(in_d, out_d, vb)
} else {
linear_no_bias(in_d, out_d, vb)
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PositionalEmbedding {
Rope,
Sin,
None,
}
#[derive(Debug, Clone)]
pub struct Config {
pub d_model: usize,
pub num_heads: usize,
pub num_layers: usize,
pub causal: bool,
pub norm_first: bool,
pub bias_ff: bool,
pub bias_attn: bool,
pub layer_scale: Option<f64>,
pub positional_embedding: PositionalEmbedding,
pub use_conv_block: bool,
pub cross_attention: bool,
pub conv_kernel_size: usize,
pub use_conv_bias: bool,
pub gating: Option<candle_nn::Activation>,
pub norm: super::NormType,
pub context: usize,
pub max_period: usize,
pub max_seq_len: usize,
pub kv_repeat: usize,
pub dim_feedforward: usize,
pub conv_layout: bool,
}
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
span: tracing::Span,
}
impl RotaryEmbedding {
pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result<Self> {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
span: tracing::span!(tracing::Level::TRACE, "rot"),
})
}
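    // Narrow the precomputed cos/sin tables to the current window starting at
    // `seqlen_offset`, then apply the interleaved rotary variant (`rope_i`),
    // which rotates consecutive pairs (x[2k], x[2k+1]) of each head dimension.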
pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?;
let qk_dtype = qk.dtype();
let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype)
}
}
#[derive(Debug, Clone)]
pub struct LayerScale {
scale: Tensor,
}
impl LayerScale {
pub fn new(d_model: usize, _init: f64, vb: VarBuilder) -> Result<Self> {
let scale = vb.get(d_model, "scale")?;
Ok(Self { scale })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.scale)
}
}
pub(crate) fn get_mask(
size1: usize,
size2: usize,
context: usize,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2)
.map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}
#[derive(Debug, Clone)]
pub struct StreamingMultiheadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
kv_repeat: usize,
num_heads: usize,
context: usize,
neg_inf: Tensor,
rope: Option<Arc<RotaryEmbedding>>,
kv_cache: candle_nn::kv_cache::KvCache,
pos: usize,
use_flash_attn: bool,
span: tracing::Span,
}
impl StreamingMultiheadAttention {
pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.d_model;
let num_kv = cfg.num_heads / cfg.kv_repeat;
let kv_dim = num_kv * (embed_dim / cfg.num_heads);
let q_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("q_proj"))?;
let k_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("k_proj"))?;
let v_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("v_proj"))?;
let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("o_proj"))?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
rope: rope.clone(),
kv_repeat: cfg.kv_repeat,
num_heads: cfg.num_heads,
context: cfg.context,
neg_inf,
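            // Concatenate along dim 2, the time axis of the (b, h, t, d) layout.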
kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
pos: 0,
use_flash_attn: false,
span: tracing::span!(tracing::Level::TRACE, "mha"),
})
}
pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
if self.kv_repeat != 1 {
candle::bail!("only kv-repeat = 1 is supported")
}
let (b, t, hd) = xs.dims3()?;
let head_dim = hd / self.num_heads;
let q = xs
.apply(&self.q_proj)?
.reshape((b, t, self.num_heads, head_dim))?;
let k = xs
.apply(&self.k_proj)?
.reshape((b, t, self.num_heads, head_dim))?;
let v = xs
.apply(&self.v_proj)?
.reshape((b, t, self.num_heads, head_dim))?;
// qk_layer_norm = None
// kv_repeat = 1, otherwise we would need repeat_kv
let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
if let Some(rope) = &self.rope {
q = rope.apply_rotary_emb(&q, self.pos)?;
k = rope.apply_rotary_emb(&k, self.pos)?;
}
let (k, v) = {
self.pos += k.dim(2)?;
self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?
};
        // The KV cache currently retains all past entries; trim the cached part
        // down to at most `context` so that it stays consistent with the mask
        // shape we provide.
let k_len = k.dim(2)?;
let k_target_len = t + usize::min(self.context, k_len - t);
let (k, v) = if k_target_len < k_len {
let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
(k, v)
} else {
(k.clone(), v.clone())
};
let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (head_dim as f32).sqrt();
flash_attn(&q, &k, &v, softmax_scale, t > 1)?.transpose(1, 2)?
} else {
let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
let pre_ws = match mask {
None => pre_ws,
Some(mask) => {
let mask = mask.broadcast_left((b, self.num_heads))?;
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
mask.where_cond(&neg_inf, &pre_ws)?
}
};
let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
ws.matmul(&v)? // b,h,t,d
};
let xs = xs
.transpose(1, 2)? // b,t,h,d
.reshape((b, t, hd))?
.apply(&self.out_proj)?;
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
self.kv_cache.reset()
}
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
self.kv_cache = kv_cache
}
}
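// A hedged streaming sketch (not part of the model): frame-by-frame decoding is
// just repeated `forward` calls with a single timestep and no mask; the kv-cache
// accumulates past keys and values while `pos` keeps RoPE aligned. `frames` is a
// placeholder for per-step inputs of shape (b, 1, d_model).
#[allow(dead_code)]
fn streaming_attention_example(
    attn: &mut StreamingMultiheadAttention,
    frames: &[Tensor],
) -> Result<Vec<Tensor>> {
    let mut outs = Vec::with_capacity(frames.len());
    for frame in frames {
        // With t == 1 there is no future position to hide, so no mask is needed.
        outs.push(attn.forward(frame, None)?);
    }
    Ok(outs)
}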
#[derive(Debug, Clone)]
pub struct StreamingMultiheadCrossAttention {
in_proj_q: Linear,
in_proj_k: Linear,
in_proj_v: Linear,
out_proj: Linear,
kv_repeat: usize,
num_heads: usize,
neg_inf: Tensor,
span: tracing::Span,
}
impl StreamingMultiheadCrossAttention {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.d_model;
let num_kv = cfg.num_heads / cfg.kv_repeat;
let kv_dim = num_kv * (embed_dim / cfg.num_heads);
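        // The q/k/v projections are stored packed row-wise, in that order, in a
        // single `in_proj_weight` of shape (embed_dim + 2 * kv_dim, embed_dim)
        // and sliced apart below.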
let out_dim = embed_dim + 2 * kv_dim;
let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
let in_proj_weight_k = in_proj_weight.narrow(0, embed_dim, kv_dim)?;
let in_proj_weight_v = in_proj_weight.narrow(0, embed_dim + kv_dim, kv_dim)?;
let (in_proj_bias_q, in_proj_bias_k, in_proj_bias_v) = if cfg.bias_attn {
let b = vb.get(out_dim, "in_proj_bias")?;
let q = b.narrow(0, 0, embed_dim)?;
let k = b.narrow(0, embed_dim, kv_dim)?;
let v = b.narrow(0, embed_dim + kv_dim, kv_dim)?;
(Some(q), Some(k), Some(v))
} else {
(None, None, None)
};
let in_proj_q = Linear::new(in_proj_weight_q, in_proj_bias_q);
let in_proj_k = Linear::new(in_proj_weight_k, in_proj_bias_k);
let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
Ok(Self {
in_proj_q,
in_proj_k,
in_proj_v,
out_proj,
kv_repeat: cfg.kv_repeat,
num_heads: cfg.num_heads,
neg_inf,
span: tracing::span!(tracing::Level::TRACE, "mhca"),
})
}
pub fn forward(&self, xs: &Tensor, ca_src: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
if self.kv_repeat != 1 {
candle::bail!("only kv-repeat = 1 is supported")
}
let (b, t, hd) = xs.dims3()?;
let head_dim = hd / self.num_heads;
// time_dim = 1, layout: b,t,h,d
let q = xs.apply(&self.in_proj_q)?;
let k = ca_src.apply(&self.in_proj_k)?;
let v = ca_src.apply(&self.in_proj_v)?;
let (ca_b, ca_t, ca_dim) = k.dims3()?;
let q = q.reshape((b, t, self.num_heads, head_dim))?;
let k = k.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
let v = v.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
// qk_layer_norm = None
// kv_repeat = 1, otherwise we would need repeat_kv
let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
let pre_ws = match mask {
None => pre_ws,
Some(mask) => {
let mask = mask.broadcast_left((b, self.num_heads))?;
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
mask.where_cond(&neg_inf, &pre_ws)?
}
};
let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
let xs = ws.matmul(&v)?; // b,h,t,d
let xs = xs
.transpose(1, 2)? // b,t,h,d
.reshape((b, t, hd))?
.apply(&self.out_proj)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub enum Mlp {
NoGating {
span1: tracing::Span,
linear1: Linear,
span2: tracing::Span,
linear2: Linear,
span: tracing::Span,
},
Gating {
linear_in: Linear,
linear_out: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
},
}
impl Mlp {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let d_model = cfg.d_model;
let span = tracing::span!(tracing::Level::TRACE, "mlp");
match cfg.gating {
None => {
let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("mlp.fc1"))?;
let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("mlp.fc2"))?;
Ok(Self::NoGating {
linear1,
linear2,
span,
span1,
span2,
})
}
Some(activation) => {
let vb = vb.pp("gating");
let hidden = if cfg.dim_feedforward == 4 * d_model {
11 * d_model / 4
} else {
2 * cfg.dim_feedforward / 3
};
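                // With the conventional dim_feedforward = 4 * d_model this picks a
                // hidden size of 11 * d_model / 4, a slightly larger rounding of
                // the usual 2/3 gating rule; otherwise 2/3 of dim_feedforward
                // keeps the parameter count close to that of the non-gated
                // two-matrix MLP.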
// TODO: Maybe use bias_ff here?
let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
Ok(Self::Gating {
linear_in,
linear_out,
activation,
span,
})
}
}
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::NoGating {
linear1,
linear2,
span,
span1,
span2,
} => {
let _enter = span.enter();
let xs = {
let _enter = span1.enter();
xs.apply(linear1)?
};
let xs = xs.gelu_erf()?;
{
let _enter = span2.enter();
xs.apply(linear2)
}
}
Self::Gating {
linear_in,
linear_out,
activation,
span,
} => {
let _enter = span.enter();
let xs = xs.apply(linear_in)?;
let (b, t, _) = xs.dims3()?;
let xs = xs.reshape((b, t, 2, ()))?;
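                // Split into gate and value halves: out = act(gate) * value,
                // SwiGLU-style gating.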
let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
xs.apply(linear_out)
}
}
}
}
#[derive(Debug, Clone)]
pub struct RmsNorm {
pub(crate) alpha: Tensor,
pub(crate) eps: f32,
}
impl RmsNorm {
pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
let alpha = vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?;
Ok(Self { alpha, eps })
}
}
impl Module for RmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
}
}
#[derive(Debug, Clone)]
pub enum Norm {
LayerNorm(candle_nn::LayerNorm),
RmsNorm(RmsNorm),
}
impl Norm {
pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let norm = match cfg.norm {
super::NormType::LayerNorm => {
let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
Self::LayerNorm(norm)
}
super::NormType::RmsNorm => {
let norm = RmsNorm::new(d_model, 1e-8, vb)?;
Self::RmsNorm(norm)
}
};
Ok(norm)
}
}
impl Module for Norm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::LayerNorm(m) => m.forward(xs),
Self::RmsNorm(m) => m.forward(xs),
}
}
}
#[derive(Debug, Clone)]
pub struct StreamingTransformerLayer {
self_attn: StreamingMultiheadAttention,
mlp: Mlp,
norm1: Norm,
norm2: Norm,
layer_scale_1: Option<LayerScale>,
layer_scale_2: Option<LayerScale>,
cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
norm_first: bool,
span: tracing::Span,
}
impl StreamingTransformerLayer {
pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
if cfg.use_conv_block {
candle::bail!("conv-block is not supported")
}
let d_model = cfg.d_model;
let mlp = Mlp::new(cfg, vb.clone())?;
let (norm1, norm2) = match cfg.norm {
super::NormType::LayerNorm => {
let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("input_layernorm"))?;
let norm2 =
candle_nn::layer_norm(d_model, 1e-5, vb.pp("post_attention_layernorm"))?;
(Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
}
super::NormType::RmsNorm => {
let norm1 = RmsNorm::new(d_model, 1e-8, vb.pp("input_rmsnorm"))?;
let norm2 = RmsNorm::new(d_model, 1e-8, vb.pp("post_attention_rmsnorm"))?;
(Norm::RmsNorm(norm1), Norm::RmsNorm(norm2))
}
};
let layer_scale_1 = match cfg.layer_scale {
None => None,
Some(ls) => {
let ls = LayerScale::new(d_model, ls, vb.pp("self_attn_layer_scale"))?;
Some(ls)
}
};
let layer_scale_2 = match cfg.layer_scale {
None => None,
Some(ls) => {
let ls = LayerScale::new(d_model, ls, vb.pp("mlp_layer_scale"))?;
Some(ls)
}
};
let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
let cross_attn = if cfg.cross_attention {
let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
Some((norm_cross, cross_attn))
} else {
None
};
Ok(Self {
self_attn,
mlp,
norm1,
norm2,
layer_scale_1,
layer_scale_2,
cross_attn,
norm_first: cfg.norm_first,
span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
})
}
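    // Pre-norm residual structure:
    //   x = x + layer_scale_1(self_attn(norm1(x)))
    //   x = x + cross_attn(norm_cross(x), ca_src)   (when cross-attention is set)
    //   x = x + layer_scale_2(mlp(norm2(x)))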
pub fn forward(
&mut self,
xs: &Tensor,
ca_src: Option<&Tensor>,
mask: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
if !self.norm_first {
candle::bail!("only norm_first = true is supported")
}
let norm1 = xs.apply(&self.norm1)?;
let xs = (xs
+ self
.self_attn
.forward(&norm1, mask)?
.apply(&self.layer_scale_1.as_ref())?)?;
let xs = match (&self.cross_attn, ca_src) {
(Some((norm_cross, cross_attn)), Some(ca_src)) => {
let residual = &xs;
let xs = xs.apply(norm_cross)?;
(residual + cross_attn.forward(&xs, ca_src, None)?)?
}
_ => xs,
};
let xs = (&xs
+ xs.apply(&self.norm2)?
.apply(&self.mlp)?
.apply(&self.layer_scale_2.as_ref()))?;
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache()
}
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
self.self_attn.set_kv_cache(kv_cache)
}
}
#[derive(Debug, Clone)]
pub struct StreamingTransformer {
layers: Vec<StreamingTransformerLayer>,
context: usize,
positional_embedding: PositionalEmbedding,
max_period: usize,
}
impl StreamingTransformer {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_l = vb.pp("layers");
let rope = match cfg.positional_embedding {
PositionalEmbedding::Rope => {
let rope = RotaryEmbedding::new(
cfg.d_model / cfg.num_heads,
cfg.max_seq_len,
cfg.max_period as f32,
vb.device(),
)?;
Some(Arc::new(rope))
}
PositionalEmbedding::Sin | PositionalEmbedding::None => None,
};
let mut layers = Vec::with_capacity(cfg.num_layers);
for layer_idx in 0..cfg.num_layers {
let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
Ok(Self {
layers,
context: cfg.context,
positional_embedding: cfg.positional_embedding,
max_period: cfg.max_period,
})
}
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
self.forward_ca(xs, None)
}
pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
let (_b, t, c) = xs.dims3()?;
        // At most `context` positions are read back from the kv-cache; the mask
        // discards any values that fall further in the past.
let pos = self.layers[0]
.self_attn
.kv_cache
.k_cache()
.current_seq_len()
.min(self.context);
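        // A single-step query (t == 1, the streaming case) may attend to every
        // cached position, so no mask is needed; the attention layer itself trims
        // the cache to at most `context`.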
let mask = if t == 1 {
None
} else {
Some(get_mask(t, pos + t, self.context, xs.device())?)
};
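        // Rope is applied inside the attention layers and None adds nothing here;
        // Sin is the usual sinusoidal table with frequencies
        // max_period^(-i / (half_dim - 1)), cos in the first half of the channel
        // dim and sin in the second.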
let mut xs = match self.positional_embedding {
PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
PositionalEmbedding::Sin => {
let dev = xs.device();
let theta = self.max_period as f32;
let half_dim = c / 2;
let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)?
.unsqueeze(1)?
.to_dtype(DType::F32)?;
let inv_freq: Vec<_> = (0..half_dim)
.map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let freqs = positions.broadcast_mul(&inv_freq)?;
let pos_emb =
Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?.to_dtype(xs.dtype())?;
xs.broadcast_add(&pos_emb)?
}
};
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, ca_src, mask.as_ref())?;
}
Ok(xs)
}
pub fn copy_state(&mut self, from: &Self) -> Result<()> {
if self.layers.len() != from.layers.len() {
candle::bail!("cannot copy kv-caches as the transformers have different depths")
}
self.layers
.iter_mut()
.zip(from.layers.iter())
.for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
Ok(())
}
}
impl StreamingModule for StreamingTransformer {
fn reset_state(&mut self) {
self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
match xs.as_option() {
None => Ok(StreamTensor::empty()),
Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
}
}
}
#[derive(Debug, Clone)]
pub struct ProjectedTransformer {
transformer: StreamingTransformer,
input_proj: Option<Linear>,
output_projs: Vec<Option<Linear>>,
conv_layout: bool,
span: tracing::Span,
}
impl ProjectedTransformer {
pub fn new(
input_dim: usize,
output_dims: &[usize],
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let transformer = StreamingTransformer::new(cfg, vb.clone())?;
let input_proj = if input_dim == cfg.d_model {
None
} else {
let l = linear_no_bias(input_dim, cfg.d_model, vb.pp("input_proj"))?;
Some(l)
};
let mut output_projs = Vec::with_capacity(output_dims.len());
let vb_o = vb.pp("output_projs");
for (i, &output_dim) in output_dims.iter().enumerate() {
let output_proj = if output_dim == cfg.d_model {
None
} else {
let l = linear_no_bias(cfg.d_model, output_dim, vb_o.pp(i))?;
Some(l)
};
output_projs.push(output_proj)
}
Ok(Self {
transformer,
input_proj,
output_projs,
conv_layout: cfg.conv_layout,
span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
})
}
pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
let _enter = self.span.enter();
let xs = if self.conv_layout {
xs.transpose(1, 2)?
} else {
xs.clone()
};
let xs = xs.apply(&self.input_proj.as_ref())?;
let xs = self.transformer.forward(&xs)?;
let mut ys = Vec::with_capacity(self.output_projs.len());
for output_proj in self.output_projs.iter() {
let ys_ = xs.apply(&output_proj.as_ref())?;
let ys_ = if self.conv_layout {
ys_.transpose(1, 2)?
} else {
ys_
};
ys.push(ys_)
}
Ok(ys)
}
}
impl StreamingModule for ProjectedTransformer {
fn reset_state(&mut self) {
self.transformer.reset_state()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let xs = xs.apply(&|x: &Tensor| {
if self.conv_layout {
x.transpose(1, 2)
} else {
Ok(x.clone())
}
})?;
let xs = xs.apply(&self.input_proj.as_ref())?;
let xs = self.transformer.step(&xs)?;
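        // The streaming path only applies the first output projection; the other
        // heads are only available through `forward`.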
let ys = xs.apply(&self.output_projs[0].as_ref())?;
ys.apply(&|y: &Tensor| {
if self.conv_layout {
y.transpose(1, 2)
} else {
Ok(y.clone())
}
})
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}


@ -33,6 +33,7 @@ pub mod llava;
pub mod mamba;
pub mod marian;
pub mod metavoice;
pub mod mimi;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;