Make it easier to use whisper samples from the repo. (#112)
* Make it easier to use samples from the repo.
* Use f32 for accumulation in the f16/bf16 kernels.
@@ -197,8 +197,8 @@ impl Decoder {
         let (_, _, content_frames) = mel.shape().r3()?;
         let mut seek = 0;
         let mut segments = vec![];
-        let start = std::time::Instant::now();
         while seek < content_frames {
+            let start = std::time::Instant::now();
             let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
             let segment_size = usize::min(content_frames - seek, N_FRAMES);
             let mel_segment = mel.narrow(2, seek, segment_size)?;
@@ -214,7 +214,7 @@ impl Decoder {
                 duration: segment_duration,
                 dr,
             };
-            println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
+            println!("{seek}: {segment:?}, in {:?}", start.elapsed());
             segments.push(segment)
         }
         Ok(segments)
@@ -236,8 +236,9 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,

-    /// The input to be processed, in wav formats, will default to `jfk.wav`
-    /// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
+    /// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
+    /// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
+    /// repo: https://huggingface.co/datasets/Narsil/candle_demo/
     #[arg(long)]
     input: Option<String>,

@@ -286,19 +287,27 @@ fn main() -> Result<()> {
     } else {
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
         let api = Api::new()?;
+        let sample = if let Some(input) = args.input {
+            if let Some(sample) = input.strip_prefix("sample:") {
+                api.get(
+                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
+                    &format!("samples_{sample}.wav"),
+                )?
+            } else {
+                std::path::PathBuf::from(input)
+            }
+        } else {
+            println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
+            api.get(
+                &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
+                "samples_jfk.wav",
+            )?
+        };
         (
             api.get(&repo, "config.json")?,
             api.get(&repo, "tokenizer.json")?,
             api.get(&repo, "model.safetensors")?,
-            if let Some(input) = args.input {
-                std::path::PathBuf::from(input)
-            } else {
-                println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
-                api.get(
-                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                    "samples_jfk.wav",
-                )?
-            },
+            sample,
         )
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
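The hunk above is the core of the first change: the `--input` flag now accepts either a local wav path or a `sample:<name>` selector that is resolved against the sample dataset repo. A minimal, std-only Rust sketch of that dispatch follows; the `resolve_input` helper and `ResolvedInput` enum are illustrative names, not part of the diff, and the actual hf-hub download call is omitted.

use std::path::PathBuf;

// Illustrative enum: either a file to fetch from the sample dataset repo
// or a local path passed directly on the command line.
enum ResolvedInput {
    DatasetFile(String),
    LocalPath(PathBuf),
}

// Hypothetical helper mirroring the dispatch added in the hunk above:
// `sample:<name>` maps to `samples_<name>.wav`, any other value is treated
// as a local path, and no value falls back to the default JFK sample.
fn resolve_input(input: Option<&str>) -> ResolvedInput {
    match input {
        Some(input) => match input.strip_prefix("sample:") {
            Some(sample) => ResolvedInput::DatasetFile(format!("samples_{sample}.wav")),
            None => ResolvedInput::LocalPath(PathBuf::from(input)),
        },
        None => ResolvedInput::DatasetFile("samples_jfk.wav".to_string()),
    }
}

fn main() {
    // e.g. `--input sample:gb1` resolves to the dataset file `samples_gb1.wav`.
    match resolve_input(Some("sample:gb1")) {
        ResolvedInput::DatasetFile(name) => println!("fetch {name} from the dataset repo"),
        ResolvedInput::LocalPath(path) => println!("read local file {}", path.display()),
    }
}

In the actual example the dataset branch is handled by `api.get` on the `Narsil/candle-examples` dataset repo, as shown in the hunk.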
@@ -1,7 +1,7 @@
 #include "cuda_utils.cuh"
 #include<stdint.h>

-template <typename T>
+template <typename T, typename A>
 __device__ void conv1d(
     const size_t src_numel,
     const size_t l_out,
@@ -30,7 +30,7 @@ __device__ void conv1d(
   const size_t dst_l = dst_i % l_out;

   const size_t src_idx0 = b_idx * src_s[0];
-  T d = 0;
+  A d = 0;
   for (size_t offset = 0; offset < k_size; ++offset) {
     const size_t src_l_plus = stride * dst_l + offset;
     if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
@@ -38,15 +38,15 @@ __device__ void conv1d(
       for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
         const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
         const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
-        d += src[src_idx] * kernel[k_idx];
+        d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
       }
     }
   }
-  dst[dst_i] = d;
+  dst[dst_i] = static_cast<T>(d);
 }


-#define CONV1D_OP(TYPENAME, FN_NAME) \
+#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
   const size_t src_numel, \
   const size_t num_dims, \
@@ -56,19 +56,19 @@ extern "C" __global__ void FN_NAME( \
   const TYPENAME *kernel, \
   TYPENAME *dst \
 ) { \
-  conv1d(src_numel, num_dims, stride, info, src, kernel, dst); \
+  conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, info, src, kernel, dst); \
 } \

 #if __CUDA_ARCH__ >= 800
-CONV1D_OP(__nv_bfloat16, conv1d_bf16)
+CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
 #endif

 #if __CUDA_ARCH__ >= 530
-CONV1D_OP(__half, conv1d_f16)
+CONV1D_OP(__half, float, conv1d_f16)
 #endif

-CONV1D_OP(float, conv1d_f32)
-CONV1D_OP(double, conv1d_f64)
-CONV1D_OP(uint8_t, conv1d_u8)
-CONV1D_OP(uint32_t, conv1d_u32)
+CONV1D_OP(float, float, conv1d_f32)
+CONV1D_OP(double, double, conv1d_f64)
+CONV1D_OP(uint8_t, uint8_t, conv1d_u8)
+CONV1D_OP(uint32_t, uint32_t, conv1d_u32)

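The kernel change above keeps the running dot product in the accumulation type `A`/`TYPEACC` (a `float` for the `__half` and `__nv_bfloat16` instantiations) and only casts back to the storage type when writing `dst`. A rough host-side illustration of why this matters, written in Rust against the `half` crate (not candle code; the 0.01 term and 10_000-element count are arbitrary):

use half::f16;

fn main() {
    let term = f16::from_f32(0.01);

    // Accumulate 10_000 small terms, rounding back to f16 after every add,
    // which mimics keeping the accumulator in the storage type.
    let mut acc_f16 = f16::from_f32(0.0);
    for _ in 0..10_000 {
        acc_f16 = f16::from_f32(acc_f16.to_f32() + term.to_f32());
    }

    // Accumulate the same terms in f32 and cast once at the end,
    // which mimics the `A`/`TYPEACC` accumulator in the kernel.
    let mut acc_f32 = 0.0f32;
    for _ in 0..10_000 {
        acc_f32 += term.to_f32();
    }

    // The f16 accumulator stalls around 32 (0.01 is below half an ulp there),
    // while the f32 accumulator stays close to the true sum of ~100.
    println!("f16 accumulator: {}", acc_f16.to_f32());
    println!("f32 accumulator: {}", f16::from_f32(acc_f32).to_f32());
}

This is why the f16/bf16 instantiations pass `float` as `TYPEACC`, while f32, f64, u8 and u32 keep accumulating in their own type.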