Make it easier to use whisper samples from the repo. (#112)

* Make it easier to use samples from the repo.

* Use f32 for accumulation in the f16/bf16 kernels.
This commit is contained in:
Laurent Mazare
2023-07-08 18:48:27 +01:00
committed by GitHub
parent eb64ad0d4d
commit c187f347bf
2 changed files with 34 additions and 25 deletions

View File

@ -197,8 +197,8 @@ impl Decoder {
let (_, _, content_frames) = mel.shape().r3()?; let (_, _, content_frames) = mel.shape().r3()?;
let mut seek = 0; let mut seek = 0;
let mut segments = vec![]; let mut segments = vec![];
let start = std::time::Instant::now();
while seek < content_frames { while seek < content_frames {
let start = std::time::Instant::now();
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES); let segment_size = usize::min(content_frames - seek, N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?; let mel_segment = mel.narrow(2, seek, segment_size)?;
@ -214,7 +214,7 @@ impl Decoder {
duration: segment_duration, duration: segment_duration,
dr, dr,
}; };
println!("{seek}: {segment:?} : Took {:?}", start.elapsed()); println!("{seek}: {segment:?}, in {:?}", start.elapsed());
segments.push(segment) segments.push(segment)
} }
Ok(segments) Ok(segments)
@ -236,8 +236,9 @@ struct Args {
#[arg(long)] #[arg(long)]
revision: Option<String>, revision: Option<String>,
/// The input to be processed, in wav formats, will default to `jfk.wav` /// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
/// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav /// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
#[arg(long)] #[arg(long)]
input: Option<String>, input: Option<String>,
@ -286,19 +287,27 @@ fn main() -> Result<()> {
} else { } else {
let repo = Repo::with_revision(model_id, RepoType::Model, revision); let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?; let api = Api::new()?;
let sample = if let Some(input) = args.input {
if let Some(sample) = input.strip_prefix("sample:") {
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
&format!("samples_{sample}.wav"),
)?
} else {
std::path::PathBuf::from(input)
}
} else {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
"samples_jfk.wav",
)?
};
( (
api.get(&repo, "config.json")?, api.get(&repo, "config.json")?,
api.get(&repo, "tokenizer.json")?, api.get(&repo, "tokenizer.json")?,
api.get(&repo, "model.safetensors")?, api.get(&repo, "model.safetensors")?,
if let Some(input) = args.input { sample,
std::path::PathBuf::from(input)
} else {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
"samples_jfk.wav",
)?
},
) )
}; };
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

View File

@ -1,7 +1,7 @@
#include "cuda_utils.cuh" #include "cuda_utils.cuh"
#include<stdint.h> #include<stdint.h>
template <typename T> template <typename T, typename A>
__device__ void conv1d( __device__ void conv1d(
const size_t src_numel, const size_t src_numel,
const size_t l_out, const size_t l_out,
@ -30,7 +30,7 @@ __device__ void conv1d(
const size_t dst_l = dst_i % l_out; const size_t dst_l = dst_i % l_out;
const size_t src_idx0 = b_idx * src_s[0]; const size_t src_idx0 = b_idx * src_s[0];
T d = 0; A d = 0;
for (size_t offset = 0; offset < k_size; ++offset) { for (size_t offset = 0; offset < k_size; ++offset) {
const size_t src_l_plus = stride * dst_l + offset; const size_t src_l_plus = stride * dst_l + offset;
if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) { if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
@ -38,15 +38,15 @@ __device__ void conv1d(
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2]; const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2]; const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
d += src[src_idx] * kernel[k_idx]; d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
} }
} }
} }
dst[dst_i] = d; dst[dst_i] = static_cast<T>(d);
} }
#define CONV1D_OP(TYPENAME, FN_NAME) \ #define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \ extern "C" __global__ void FN_NAME( \
const size_t src_numel, \ const size_t src_numel, \
const size_t num_dims, \ const size_t num_dims, \
@ -56,19 +56,19 @@ extern "C" __global__ void FN_NAME( \
const TYPENAME *kernel, \ const TYPENAME *kernel, \
TYPENAME *dst \ TYPENAME *dst \
) { \ ) { \
conv1d(src_numel, num_dims, stride, info, src, kernel, dst); \ conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, info, src, kernel, dst); \
} \ } \
#if __CUDA_ARCH__ >= 800 #if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, conv1d_bf16) CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
#endif #endif
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, conv1d_f16) CONV1D_OP(__half, float, conv1d_f16)
#endif #endif
CONV1D_OP(float, conv1d_f32) CONV1D_OP(float, float, conv1d_f32)
CONV1D_OP(double, conv1d_f64) CONV1D_OP(double, double, conv1d_f64)
CONV1D_OP(uint8_t, conv1d_u8) CONV1D_OP(uint8_t, uint8_t, conv1d_u8)
CONV1D_OP(uint32_t, conv1d_u32) CONV1D_OP(uint32_t, uint32_t, conv1d_u32)