pub mod audio; pub mod bs1770; pub mod coco_classes; pub mod imagenet; pub mod token_output_stream; pub mod wav; use candle::utils::{cuda_is_available, metal_is_available}; use candle::{Device, Result, Tensor}; pub fn device(cpu: bool) -> Result { if cpu { Ok(Device::Cpu) } else if cuda_is_available() { Ok(Device::new_cuda(0)?) } else if metal_is_available() { Ok(Device::new_metal(0)?) } else { #[cfg(all(target_os = "macos", target_arch = "aarch64"))] { println!( "Running on CPU, to run on GPU(metal), build this example with `--features metal`" ); } #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] { println!("Running on CPU, to run on GPU, build this example with `--features cuda`"); } Ok(Device::Cpu) } } pub fn load_image>( p: P, resize_longest: Option, ) -> Result<(Tensor, usize, usize)> { let img = image::ImageReader::open(p)? .decode() .map_err(candle::Error::wrap)?; let (initial_h, initial_w) = (img.height() as usize, img.width() as usize); let img = match resize_longest { None => img, Some(resize_longest) => { let (height, width) = (img.height(), img.width()); let resize_longest = resize_longest as u32; let (height, width) = if height < width { let h = (resize_longest * height) / width; (h, resize_longest) } else { let w = (resize_longest * width) / height; (resize_longest, w) }; img.resize_exact(width, height, image::imageops::FilterType::CatmullRom) } }; let (height, width) = (img.height() as usize, img.width() as usize); let img = img.to_rgb8(); let data = img.into_raw(); let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?; Ok((data, initial_h, initial_w)) } pub fn load_image_and_resize>( p: P, width: usize, height: usize, ) -> Result { let img = image::ImageReader::open(p)? .decode() .map_err(candle::Error::wrap)? .resize_to_fill( width as u32, height as u32, image::imageops::FilterType::Triangle, ); let img = img.to_rgb8(); let data = img.into_raw(); Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1)) } /// Saves an image to disk using the image crate, this expects an input with shape /// (c, height, width). pub fn save_image>(img: &Tensor, p: P) -> Result<()> { let p = p.as_ref(); let (channel, height, width) = img.dims3()?; if channel != 3 { candle::bail!("save_image expects an input of shape (3, height, width)") } let img = img.permute((1, 2, 0))?.flatten_all()?; let pixels = img.to_vec1::()?; let image: image::ImageBuffer, Vec> = match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { Some(image) => image, None => candle::bail!("error saving image {p:?}"), }; image.save(p).map_err(candle::Error::wrap)?; Ok(()) } pub fn save_image_resize>( img: &Tensor, p: P, h: usize, w: usize, ) -> Result<()> { let p = p.as_ref(); let (channel, height, width) = img.dims3()?; if channel != 3 { candle::bail!("save_image expects an input of shape (3, height, width)") } let img = img.permute((1, 2, 0))?.flatten_all()?; let pixels = img.to_vec1::()?; let image: image::ImageBuffer, Vec> = match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { Some(image) => image, None => candle::bail!("error saving image {p:?}"), }; let image = image::DynamicImage::from(image); let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom); image.save(p).map_err(candle::Error::wrap)?; Ok(()) } /// Loads the safetensors files for a model from the hub based on a json index file. pub fn hub_load_safetensors( repo: &hf_hub::api::sync::ApiRepo, json_file: &str, ) -> Result> { let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; let json_file = std::fs::File::open(json_file)?; let json: serde_json::Value = serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; let weight_map = match json.get("weight_map") { None => candle::bail!("no weight map in {json_file:?}"), Some(serde_json::Value::Object(map)) => map, Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), }; let mut safetensors_files = std::collections::HashSet::new(); for value in weight_map.values() { if let Some(file) = value.as_str() { safetensors_files.insert(file.to_string()); } } let safetensors_files = safetensors_files .iter() .map(|v| repo.get(v).map_err(candle::Error::wrap)) .collect::>>()?; Ok(safetensors_files) } pub fn hub_load_local_safetensors>( path: P, json_file: &str, ) -> Result> { let path = path.as_ref(); let jsfile = std::fs::File::open(path.join(json_file))?; let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle::Error::wrap)?; let weight_map = match json.get("weight_map") { None => candle::bail!("no weight map in {json_file:?}"), Some(serde_json::Value::Object(map)) => map, Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), }; let mut safetensors_files = std::collections::HashSet::new(); for value in weight_map.values() { if let Some(file) = value.as_str() { safetensors_files.insert(file); } } let safetensors_files: Vec<_> = safetensors_files .into_iter() .map(|v| path.join(v)) .collect(); Ok(safetensors_files) }