mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Matmul (no batch, no strided, f32, f32 only) sort of done.
This commit is contained in:
@ -2,25 +2,27 @@ pub mod coco_classes;
|
||||
pub mod imagenet;
|
||||
pub mod token_output_stream;
|
||||
|
||||
use candle::{Device, Result, Tensor};
|
||||
use candle::utils::{cuda_is_available, metal_is_available};
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
pub fn device(cpu: bool) -> Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else {
|
||||
if cuda_is_available(){
|
||||
if cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
}else if metal_is_available(){
|
||||
} else if metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
}else{
|
||||
#[cfg(all(target_os="macos", target_arch="aarch64"))]
|
||||
} else {
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
{
|
||||
println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
|
||||
}
|
||||
#[cfg(not(all(target_os="macos", target_arch="aarch64")))]
|
||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||
{
|
||||
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
|
||||
println!(
|
||||
"Running on CPU, to run on GPU, build this example with `--features cuda`"
|
||||
);
|
||||
}
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
|
Reference in New Issue
Block a user