mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Compare commits
1 Commits
0.9.0-alph
...
conv1d-alg
Author | SHA1 | Date | |
---|---|---|---|
8e62723b2d |
18
Cargo.toml
18
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.2"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
|
|||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
|
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
|
||||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.3" }
|
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.2" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.3" }
|
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.2" }
|
||||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.3" }
|
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.2" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.3" }
|
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.2" }
|
||||||
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.3" }
|
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.2" }
|
||||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.3" }
|
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.2" }
|
||||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.3" }
|
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.2" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
|
@ -6,18 +6,28 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
let device = Device::new_cuda(0)?;
|
||||||
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
|
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||||
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||||
|
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
||||||
|
let _x1 = x.matmul(&x)?;
|
||||||
|
drop(_x1);
|
||||||
|
let start_time = std::time::Instant::now();
|
||||||
|
let _x1 = x.matmul(&x)?;
|
||||||
|
device.synchronize()?;
|
||||||
|
println!("fp32: {:?}", start_time.elapsed());
|
||||||
|
drop(_x1);
|
||||||
|
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||||
|
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
||||||
|
let _x1 = x.matmul(&x)?;
|
||||||
|
drop(_x1);
|
||||||
|
let start_time = std::time::Instant::now();
|
||||||
|
let _x1 = x.matmul(&x)?;
|
||||||
|
device.synchronize()?;
|
||||||
|
println!("tf32: {:?}", start_time.elapsed());
|
||||||
drop(_x1);
|
drop(_x1);
|
||||||
for _ in 0..20 {
|
|
||||||
let start_time = std::time::Instant::now();
|
|
||||||
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
|
||||||
device.synchronize()?;
|
|
||||||
println!("conv1d: {:?}", start_time.elapsed());
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ impl ParamsConvTranspose1D {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum CudnnFwdAlgo {
|
pub enum CudnnFwdAlgo {
|
||||||
ImplicitGemm,
|
ImplicitGemm,
|
||||||
ImplicitPrecompGemm,
|
ImplicitPrecompGemm,
|
||||||
@ -152,19 +152,6 @@ impl Tensor {
|
|||||||
stride: usize,
|
stride: usize,
|
||||||
dilation: usize,
|
dilation: usize,
|
||||||
groups: usize,
|
groups: usize,
|
||||||
) -> Result<Self> {
|
|
||||||
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a 1D convolution over the input tensor.
|
|
||||||
pub fn conv1d_with_algo(
|
|
||||||
&self,
|
|
||||||
kernel: &Self,
|
|
||||||
padding: usize,
|
|
||||||
stride: usize,
|
|
||||||
dilation: usize,
|
|
||||||
groups: usize,
|
|
||||||
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||||
let (b_size, c_in, l_in) = self.dims3()?;
|
let (b_size, c_in, l_in) = self.dims3()?;
|
||||||
@ -188,7 +175,7 @@ impl Tensor {
|
|||||||
padding,
|
padding,
|
||||||
stride,
|
stride,
|
||||||
dilation,
|
dilation,
|
||||||
cudnn_fwd_algo,
|
cudnn_fwd_algo: Some(CudnnFwdAlgo::ImplicitGemm),
|
||||||
};
|
};
|
||||||
if groups == 1 {
|
if groups == 1 {
|
||||||
self.conv1d_single_group(kernel, ¶ms)
|
self.conv1d_single_group(kernel, ¶ms)
|
||||||
@ -293,18 +280,6 @@ impl Tensor {
|
|||||||
stride: usize,
|
stride: usize,
|
||||||
dilation: usize,
|
dilation: usize,
|
||||||
groups: usize,
|
groups: usize,
|
||||||
) -> Result<Self> {
|
|
||||||
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn conv2d_with_algo(
|
|
||||||
&self,
|
|
||||||
kernel: &Self,
|
|
||||||
padding: usize,
|
|
||||||
stride: usize,
|
|
||||||
dilation: usize,
|
|
||||||
groups: usize,
|
|
||||||
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||||
@ -324,7 +299,7 @@ impl Tensor {
|
|||||||
padding,
|
padding,
|
||||||
stride,
|
stride,
|
||||||
dilation,
|
dilation,
|
||||||
cudnn_fwd_algo,
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
if groups == 1 {
|
if groups == 1 {
|
||||||
self.conv2d_single_group(kernel, ¶ms)
|
self.conv2d_single_group(kernel, ¶ms)
|
||||||
|
@ -46,7 +46,7 @@ impl TextGeneration {
|
|||||||
Sampling::ArgMax
|
Sampling::ArgMax
|
||||||
} else {
|
} else {
|
||||||
match (top_k, top_p) {
|
match (top_k, top_p) {
|
||||||
(None, None) => Sampling::GumbelSoftmax { temperature },
|
(None, None) => Sampling::All { temperature },
|
||||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
@ -133,7 +133,6 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
|||||||
padding,
|
padding,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let conv = if bias {
|
let conv = if bias {
|
||||||
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
||||||
|
@ -92,7 +92,6 @@ impl ConvBlock {
|
|||||||
stride,
|
stride,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
|
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.2"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.3" }
|
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.2" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.2"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.2"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
//! Convolution Layers.
|
//! Convolution Layers.
|
||||||
use crate::BatchNorm;
|
use crate::BatchNorm;
|
||||||
use candle::{conv::CudnnFwdAlgo, Result, Tensor};
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub struct Conv1dConfig {
|
pub struct Conv1dConfig {
|
||||||
@ -8,7 +8,6 @@ pub struct Conv1dConfig {
|
|||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub dilation: usize,
|
pub dilation: usize,
|
||||||
pub groups: usize,
|
pub groups: usize,
|
||||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Conv1dConfig {
|
impl Default for Conv1dConfig {
|
||||||
@ -18,7 +17,6 @@ impl Default for Conv1dConfig {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -54,13 +52,12 @@ impl Conv1d {
|
|||||||
|
|
||||||
impl crate::Module for Conv1d {
|
impl crate::Module for Conv1d {
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x = x.conv1d_with_algo(
|
let x = x.conv1d(
|
||||||
&self.weight,
|
&self.weight,
|
||||||
self.config.padding,
|
self.config.padding,
|
||||||
self.config.stride,
|
self.config.stride,
|
||||||
self.config.dilation,
|
self.config.dilation,
|
||||||
self.config.groups,
|
self.config.groups,
|
||||||
self.config.cudnn_fwd_algo,
|
|
||||||
)?;
|
)?;
|
||||||
match &self.bias {
|
match &self.bias {
|
||||||
None => Ok(x),
|
None => Ok(x),
|
||||||
@ -150,7 +147,6 @@ pub struct Conv2dConfig {
|
|||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub dilation: usize,
|
pub dilation: usize,
|
||||||
pub groups: usize,
|
pub groups: usize,
|
||||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Conv2dConfig {
|
impl Default for Conv2dConfig {
|
||||||
@ -160,7 +156,6 @@ impl Default for Conv2dConfig {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -216,13 +211,12 @@ impl Conv2d {
|
|||||||
|
|
||||||
impl crate::Module for Conv2d {
|
impl crate::Module for Conv2d {
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x = x.conv2d_with_algo(
|
let x = x.conv2d(
|
||||||
&self.weight,
|
&self.weight,
|
||||||
self.config.padding,
|
self.config.padding,
|
||||||
self.config.stride,
|
self.config.stride,
|
||||||
self.config.dilation,
|
self.config.dilation,
|
||||||
self.config.groups,
|
self.config.groups,
|
||||||
self.config.cudnn_fwd_algo,
|
|
||||||
)?;
|
)?;
|
||||||
match &self.bias {
|
match &self.bias {
|
||||||
None => Ok(x),
|
None => Ok(x),
|
||||||
|
@ -31,7 +31,6 @@ pub mod ops;
|
|||||||
pub mod optim;
|
pub mod optim;
|
||||||
pub mod rnn;
|
pub mod rnn;
|
||||||
pub mod rotary_emb;
|
pub mod rotary_emb;
|
||||||
pub mod sampling;
|
|
||||||
pub mod sequential;
|
pub mod sequential;
|
||||||
pub mod var_builder;
|
pub mod var_builder;
|
||||||
pub mod var_map;
|
pub mod var_map;
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
use candle::{Result, Tensor};
|
|
||||||
|
|
||||||
/// Sample according to the Gumbel-Softmax distribution.
|
|
||||||
pub fn gumbel_softmax<D: candle::shape::Dim>(
|
|
||||||
logits: &Tensor,
|
|
||||||
temperature: f64,
|
|
||||||
dim: D,
|
|
||||||
) -> Result<Tensor> {
|
|
||||||
if temperature <= 0.0 {
|
|
||||||
logits.argmax(dim)
|
|
||||||
} else if temperature == 1.0 {
|
|
||||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
|
||||||
let sampled = (logits - minus_g)?.argmax(dim)?;
|
|
||||||
Ok(sampled)
|
|
||||||
} else {
|
|
||||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
|
||||||
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
|
||||||
Ok(sampled)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.2"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
|
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.3" }
|
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.2" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -13,8 +13,6 @@ pub enum Sampling {
|
|||||||
TopK { k: usize, temperature: f64 },
|
TopK { k: usize, temperature: f64 },
|
||||||
TopP { p: f64, temperature: f64 },
|
TopP { p: f64, temperature: f64 },
|
||||||
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
||||||
// Note that the rng is not used for the Gumbel-Softmax sampling.
|
|
||||||
GumbelSoftmax { temperature: f64 },
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitsProcessor {
|
pub struct LogitsProcessor {
|
||||||
@ -51,11 +49,6 @@ impl LogitsProcessor {
|
|||||||
Ok(next_token)
|
Ok(next_token)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
|
|
||||||
let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
|
|
||||||
sampled.to_vec0::<u32>()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||||
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||||
let next_token = distr.sample(&mut self.rng) as u32;
|
let next_token = distr.sample(&mut self.rng) as u32;
|
||||||
@ -134,9 +127,6 @@ impl LogitsProcessor {
|
|||||||
|
|
||||||
let next_token = match &self.sampling {
|
let next_token = match &self.sampling {
|
||||||
Sampling::ArgMax => self.sample_argmax(logits)?,
|
Sampling::ArgMax => self.sample_argmax(logits)?,
|
||||||
Sampling::GumbelSoftmax { temperature } => {
|
|
||||||
self.sample_gumbel_softmax(&logits, *temperature)?
|
|
||||||
}
|
|
||||||
Sampling::All { temperature } => {
|
Sampling::All { temperature } => {
|
||||||
let prs = prs(*temperature)?;
|
let prs = prs(*temperature)?;
|
||||||
self.sample_multinomial(&prs)?
|
self.sample_multinomial(&prs)?
|
||||||
|
@ -124,7 +124,6 @@ impl ResidualConvUnit {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let conv1 = conv2d(
|
let conv1 = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -209,7 +208,6 @@ impl FeatureFusionBlock {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let output_conv = conv2d(
|
let output_conv = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -260,7 +258,6 @@ impl Scratch {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let layer1_rn = conv2d_no_bias(
|
let layer1_rn = conv2d_no_bias(
|
||||||
@ -322,7 +319,6 @@ impl Scratch {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let output_conv1 = conv2d(
|
let output_conv1 = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -429,7 +425,6 @@ impl DPTHead {
|
|||||||
stride: 2,
|
stride: 2,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
},
|
},
|
||||||
vb.pp("resize_layers").pp("3"),
|
vb.pp("resize_layers").pp("3"),
|
||||||
)?),
|
)?),
|
||||||
|
@ -468,7 +468,6 @@ impl EncodecConv1d {
|
|||||||
stride,
|
stride,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
},
|
},
|
||||||
vb.pp("conv"),
|
vb.pp("conv"),
|
||||||
)?,
|
)?,
|
||||||
|
@ -267,7 +267,6 @@ impl StreamableConv1d {
|
|||||||
stride,
|
stride,
|
||||||
dilation,
|
dilation,
|
||||||
groups,
|
groups,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
|
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
|
||||||
if k_size < stride {
|
if k_size < stride {
|
||||||
|
@ -68,7 +68,6 @@ impl ResnetBlock2D {
|
|||||||
padding: 1,
|
padding: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
|
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
|
||||||
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
|
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
|
||||||
@ -84,7 +83,6 @@ impl ResnetBlock2D {
|
|||||||
padding: 0,
|
padding: 0,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
Some(conv2d(
|
Some(conv2d(
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -248,14 +248,12 @@ impl AudioEncoder {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let cfg2 = Conv1dConfig {
|
let cfg2 = Conv1dConfig {
|
||||||
padding: 1,
|
padding: 1,
|
||||||
stride: 2,
|
stride: 2,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||||
|
@ -244,14 +244,12 @@ impl AudioEncoder {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let cfg2 = Conv1dConfig {
|
let cfg2 = Conv1dConfig {
|
||||||
padding: 1,
|
padding: 1,
|
||||||
stride: 2,
|
stride: 2,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||||
|
@ -54,25 +54,3 @@ fn sample_with_top_k() -> Result<()> {
|
|||||||
assert_eq!(token, 2);
|
assert_eq!(token, 2);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn sample_gumbel() -> Result<()> {
|
|
||||||
let mut logits_process = LogitsProcessor::from_sampling(
|
|
||||||
42,
|
|
||||||
candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 },
|
|
||||||
);
|
|
||||||
let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?;
|
|
||||||
let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::<f64>()?;
|
|
||||||
let mut counts = vec![0f64; 4];
|
|
||||||
let samples = 100000;
|
|
||||||
for _ in 0..samples {
|
|
||||||
let token = logits_process.sample(&logits)?;
|
|
||||||
counts[token as usize] += 1f64 / samples as f64;
|
|
||||||
}
|
|
||||||
for i in 0..4 {
|
|
||||||
if (counts[i] - sm[i]).abs() > 0.05 {
|
|
||||||
panic!("pr mismatch {counts:?} {sm:?}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
@ -98,7 +98,6 @@ impl ConvBlock {
|
|||||||
stride,
|
stride,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
|
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
|
||||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||||
|
Reference in New Issue
Block a user