mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Switch the default to using the faster kernels. (#1978)
* Switch the default to using the faster kernels.
* Add the force-dmmv flag.
This commit is contained in:
@ -10,7 +10,7 @@ pub struct QCudaStorage {
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(true);
|
||||
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
||||
|
||||
pub fn set_force_dmmv(f: bool) {
|
||||
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
|
||||
|
@ -196,6 +196,10 @@ struct Args {
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// Use the slower dmmv cuda kernel.
|
||||
#[arg(long)]
|
||||
force_dmmv: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -203,6 +207,9 @@ fn main() -> Result<()> {
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
#[cfg(feature = "cuda")]
|
||||
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
|
@ -236,9 +236,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
gqa: Option<usize>,
|
||||
|
||||
/// Use the (experimental) fast cuda kernels.
|
||||
/// Use the slower dmmv cuda kernel.
|
||||
#[arg(long)]
|
||||
fast_cuda: bool,
|
||||
force_dmmv: bool,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -347,7 +347,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
candle::quantized::cuda::set_force_dmmv(!args.fast_cuda);
|
||||
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
|
||||
|
||||
let temperature = if args.temperature == 0. {
|
||||
None
|
||||
|
Reference in New Issue
Block a user