Rename the candle crate to candle-core (#301)

* Rename to candle-core.

* More candle-core renaming.
Author: Laurent Mazare
Date: 2023-08-02 08:20:22 +01:00
Committed by: GitHub
Parent: 6e33ff62d6
Commit: 51e51da896
23 changed files with 77 additions and 76 deletions

View File

@@ -5,8 +5,8 @@ We will now create the hello world of the ML world, building a model capable of
 Open `src/main.rs` and fill in this content:
 ```rust
-# extern crate candle;
-use candle::{DType, Device, Result, Tensor};
+# extern crate candle_core;
+use candle_core::{DType, Device, Result, Tensor};
 struct Model {
     first: Tensor,
@@ -49,8 +49,8 @@ Now that we have this, we might want to complexify things a bit, for instance by
 the classical `Linear` layer. We can do as such
 ```rust
-# extern crate candle;
-# use candle::{DType, Device, Result, Tensor};
+# extern crate candle_core;
+# use candle_core::{DType, Device, Result, Tensor};
 struct Linear{
     weight: Tensor,
     bias: Tensor,
@@ -79,8 +79,8 @@ impl Model {
 This will change the model running code into a new function
 ```rust
-# extern crate candle;
-# use candle::{DType, Device, Result, Tensor};
+# extern crate candle_core;
+# use candle_core::{DType, Device, Result, Tensor};
 # struct Linear{
 # weight: Tensor,
 # bias: Tensor,
@@ -144,9 +144,9 @@ cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
 And rewrite our examples using it
 ```rust
-# extern crate candle;
+# extern crate candle_core;
 # extern crate candle_nn;
-use candle::{DType, Device, Result, Tensor};
+use candle_core::{DType, Device, Result, Tensor};
 use candle_nn::Linear;
 struct Model {

View File

@@ -1,5 +1,5 @@
 [package]
-name = "candle"
+name = "candle-core"
 version = "0.1.0"
 edition = "2021"

View File

@@ -2,7 +2,7 @@
 extern crate intel_mkl_src;
 use anyhow::Result;
-use candle::{Device, Tensor};
+use candle_core::{Device, Tensor};
 fn main() -> Result<()> {
     let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;

View File

@@ -2,7 +2,7 @@
 extern crate intel_mkl_src;
 use anyhow::Result;
-use candle::{Device, Tensor};
+use candle_core::{Device, Tensor};
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;

View File

@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
 use std::str::FromStr;
 use anyhow::Result;
-use candle::{Device, Tensor};
+use candle_core::{Device, Tensor};
 fn cos_sin(n: usize, device: &Device) -> Result<Tensor> {
     let thetas: Vec<_> = (0..n).map(|i| (i as f32 / n as f32)).collect();

View File

@@ -7,7 +7,7 @@ impl Tensor {
     /// Intended to be use by the trait `.i()`
     ///
     /// ```
-    /// # use candle::{Tensor, DType, Device, IndexOp};
+    /// # use candle_core::{Tensor, DType, Device, IndexOp};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     ///
     /// let c = a.i(0..1)?;
@@ -22,7 +22,7 @@ impl Tensor {
     /// let c = a.i((.., ..=2))?;
     /// assert_eq!(c.shape().dims(), &[2, 3]);
     ///
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
         let mut x = self.clone();
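The `# Ok::<(), candle_core::Error>(())` line this commit renames in dozens of doc tests is the standard rustdoc idiom for using `?` without writing an explicit `fn main`: the hidden, turbofish-typed final expression gives the example body a concrete `Result` type. The same trick written as ordinary Rust, with a standard-library error type standing in for `candle_core::Error`:

```rust
// The closure plays the role of the doctest body: `?` works because the
// final turbofish expression pins the closure's Result type.
fn main() {
    let result = (|| {
        let n: i32 = "42".parse()?;
        assert_eq!(n, 42);
        Ok::<(), std::num::ParseIntError>(())
    })();
    assert!(result.is_ok());
}
```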

View File

@@ -1,8 +1,8 @@
 //! ML framework for Rust
 //!
 //! ```rust
-//! use candle::{Tensor, DType, Device};
-//! # use candle::Error;
+//! use candle_core::{Tensor, DType, Device};
+//! # use candle_core::Error;
 //! # fn main() -> Result<(), Error>{
 //!
 //! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;

View File

@@ -54,13 +54,13 @@ impl AsRef<Tensor> for Tensor {
 /// The core struct for manipulating tensors.
 ///
 /// ```rust
-/// use candle::{Tensor, DType, Device};
+/// use candle_core::{Tensor, DType, Device};
 ///
 /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
 /// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
 ///
 /// let c = a.matmul(&b)?;
-/// # Ok::<(), candle::Error>(())
+/// # Ok::<(), candle_core::Error>(())
 /// ```
 ///
 /// Tensors are reference counted with [`Arc`] so cloning them is cheap.
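The context line above documents that `Tensor` clones are shallow `Arc` bumps. A minimal sketch of what that means in practice, using only calls that appear elsewhere in this diff:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // `clone` only bumps the reference count on the shared storage;
    // no tensor data is copied.
    let b = a.clone();
    assert_eq!(b.shape().dims(), &[2, 3]);
    Ok(())
}
```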
@@ -163,11 +163,11 @@ impl Tensor {
     /// Creates a new tensor filled with ones.
     ///
     /// ```rust
-    /// use candle::{Tensor, DType, Device};
+    /// use candle_core::{Tensor, DType, Device};
     /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
     /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
     /// // a == b
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, false)
@@ -176,11 +176,11 @@ impl Tensor {
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
     ///
     /// ```rust
-    /// use candle::{Tensor, DType, Device};
+    /// use candle_core::{Tensor, DType, Device};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     /// let b = a.ones_like()?;
     /// // b == a + 1
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn ones_like(&self) -> Result<Self> {
         Tensor::ones(self.shape(), self.dtype(), self.device())
@@ -208,11 +208,11 @@ impl Tensor {
     /// Creates a new tensor filled with zeros.
     ///
     /// ```rust
-    /// use candle::{Tensor, DType, Device};
+    /// use candle_core::{Tensor, DType, Device};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
     /// // a == b
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, false)
@@ -222,11 +222,11 @@ impl Tensor {
     /// tensor.
     ///
     /// ```rust
-    /// use candle::{Tensor, DType, Device};
+    /// use candle_core::{Tensor, DType, Device};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     /// let b = a.zeros_like()?;
     /// // b is on CPU f32.
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn zeros_like(&self) -> Result<Self> {
         Tensor::zeros(self.shape(), self.dtype(), self.device())
@@ -516,11 +516,11 @@ impl Tensor {
     /// be performed.
     ///
     /// ```rust
-    /// use candle::{Tensor, Device};
+    /// use candle_core::{Tensor, Device};
     /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
     /// let a = a.affine(4., -2.)?;
    /// assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
         let storage = self.storage().affine(self.layout(), mul, add)?;
@@ -642,7 +642,7 @@ impl Tensor {
     /// that the number of elements for each dimension index in `sum_dims` is 1.
     ///
     /// ```rust
-    /// use candle::{Tensor, Device};
+    /// use candle_core::{Tensor, Device};
     /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
     /// let s = a.sum_keepdim(0)?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
@@ -650,7 +650,7 @@ impl Tensor {
     /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
     /// let s = a.sum_keepdim((0, 1))?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
         self.sum_impl(sum_dims, true)
@@ -854,12 +854,12 @@ impl Tensor {
     /// vocabulary size, and `h` the hidden size.
     ///
     /// ```rust
-    /// use candle::{Tensor, Device};
+    /// use candle_core::{Tensor, Device};
     /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
     /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
     /// let emb = values.embedding(&ids)?;
     /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn embedding(&self, ids: &Self) -> Result<Self> {
         if self.rank() != 2 || ids.rank() != 1 {
@@ -1191,11 +1191,11 @@ impl Tensor {
     /// scalar with zero dimensions.
     ///
     /// ```rust
-    /// use candle::{Tensor, Device};
+    /// use candle_core::{Tensor, Device};
     /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
     /// let tensor = tensor.sum_all()?;
     /// assert_eq!(tensor.to_scalar::<f32>()?, 15.);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn sum_all(&self) -> Result<Tensor> {
         let dims: Vec<_> = (0..self.rank()).collect();
@@ -1252,11 +1252,11 @@ impl Tensor {
     /// Flattens the input tensor by reshaping it into a one dimension tensor.
     ///
     /// ```rust
-    /// use candle::{Tensor, Device};
+    /// use candle_core::{Tensor, Device};
     /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
     /// let tensor = tensor.flatten_all()?;
     /// assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn flatten_all(&self) -> Result<Tensor> {
         self.flatten_(None::<usize>, None::<usize>)
@@ -1265,13 +1265,13 @@ impl Tensor {
     /// Returns the sub-tensor fixing the index at `i` on the first dimension.
     ///
     /// ```rust
-    /// use candle::{Tensor, Device};
+    /// use candle_core::{Tensor, Device};
     /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
     /// let t = tensor.get(0)?;
     /// assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
     /// let t = tensor.get(1)?;
     /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn get(&self, i: usize) -> Result<Tensor> {
         let dims = self.dims();
@@ -1286,11 +1286,11 @@ impl Tensor {
     /// input are swapped.
     ///
     /// ```rust
-    /// use candle::{Tensor, Device};
+    /// use candle_core::{Tensor, Device};
     /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
     /// let tensor = tensor.t()?;
     /// assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn t(&self) -> Result<Tensor> {
         let rank = self.rank();
@@ -1433,12 +1433,12 @@ impl Tensor {
     /// Casts the input tensor to the target `dtype`.
     ///
     /// ```rust
-    /// use candle::{Tensor, Device};
+    /// use candle_core::{Tensor, Device};
     /// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
     /// assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
-    /// let tensor = tensor.to_dtype(candle::DType::F32)?;
+    /// let tensor = tensor.to_dtype(candle_core::DType::F32)?;
     /// assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
         if self.dtype() == dtype {
@@ -1483,7 +1483,7 @@ impl Tensor {
     /// a new storage and copies the data over, the returned tensor is always contiguous.
     ///
     /// ```rust
-    /// # use candle::{Tensor, DType, Device, D};
+    /// # use candle_core::{Tensor, DType, Device, D};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     ///
     /// let c = a.reshape((1, 6))?;
@@ -1491,7 +1491,7 @@ impl Tensor {
     ///
     /// let c = a.reshape((3, 2))?;
     /// assert_eq!(c.shape().dims(), &[3, 2]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
         let shape = shape.into();
@@ -1526,7 +1526,7 @@ impl Tensor {
     /// Creates a new tensor with the specified dimension removed if its size was one.
     ///
     /// ```rust
-    /// # use candle::{Tensor, DType, Device, D};
+    /// # use candle_core::{Tensor, DType, Device, D};
     /// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
     ///
     /// let c = a.squeeze(2)?;
@@ -1534,7 +1534,7 @@ impl Tensor {
     ///
     /// let c = a.squeeze(D::Minus1)?;
     /// assert_eq!(c.shape().dims(), &[2, 3]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
         // The PyTorch semantics are to return the same tensor if the target dimension
@@ -1553,7 +1553,7 @@ impl Tensor {
     /// Creates a new tensor with a dimension of size one inserted at the specified position.
     ///
     /// ```rust
-    /// # use candle::{Tensor, DType, Device, D};
+    /// # use candle_core::{Tensor, DType, Device, D};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     ///
     /// let c = a.unsqueeze(0)?;
@@ -1561,7 +1561,7 @@ impl Tensor {
     ///
     /// let c = a.unsqueeze(D::Minus1)?;
     /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
         let mut dims = self.dims().to_vec();
@@ -1576,7 +1576,7 @@ impl Tensor {
     /// All tensors must have the same rank, and the output has one additional rank
     ///
     /// ```rust
-    /// # use candle::{Tensor, DType, Device};
+    /// # use candle_core::{Tensor, DType, Device};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     ///
@@ -1585,7 +1585,7 @@ impl Tensor {
     ///
     /// let c = Tensor::stack(&[&a, &b], 2)?;
     /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
@@ -1605,7 +1605,7 @@ impl Tensor {
     /// the same rank
     ///
     /// ```rust
-    /// # use candle::{Tensor, DType, Device};
+    /// # use candle_core::{Tensor, DType, Device};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
     ///
@@ -1614,7 +1614,7 @@ impl Tensor {
     ///
     /// let c = Tensor::cat(&[&a, &b], 1)?;
     /// assert_eq!(c.shape().dims(), &[2, 6]);
-    /// # Ok::<(), candle::Error>(())
+    /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {

View File

@@ -1,6 +1,6 @@
-use candle::backend::BackendStorage;
-use candle::cpu_backend;
-use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+use candle_core::backend::BackendStorage;
+use candle_core::cpu_backend;
+use candle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
 mod test_utils;
 use test_utils::to_vec1_round;
@@ -24,7 +24,7 @@ impl CustomOp1 for Elu {
     }
     fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
-        let storage = candle::map_dtype!(
+        let storage = candle_core::map_dtype!(
             "elu",
             s,
             |s| cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)),
@@ -67,7 +67,7 @@ impl CustomOp1 for EluBackward {
     }
     fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
-        let storage = candle::map_dtype!(
+        let storage = candle_core::map_dtype!(
             "elu-bwd",
             s,
             |s| cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)),
@@ -104,7 +104,7 @@ impl CustomOp1 for EluWithBackward {
 #[test]
 fn custom_op1_with_backward() -> Result<()> {
     let cpu = &Device::Cpu;
-    let t = candle::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
+    let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
     let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
     assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);

View File

@@ -1,5 +1,5 @@
 use anyhow::Result;
-use candle::{DType, Device::Cpu, Tensor};
+use candle_core::{DType, Device::Cpu, Tensor};
 #[test]
 fn display_scalar() -> Result<()> {

View File

@@ -1,5 +1,5 @@
 use anyhow::{Context, Result};
-use candle::{Device, Shape, Tensor, Var};
+use candle_core::{Device, Shape, Tensor, Var};
 mod test_utils;
 fn simple_grad(device: &Device) -> Result<()> {

View File

@@ -1,5 +1,5 @@
 use anyhow::Result;
-use candle::{Device, IndexOp, Tensor};
+use candle_core::{Device, IndexOp, Tensor};
 mod test_utils;

View File

@@ -1,5 +1,6 @@
 mod test_utils;
 use candle::{Device, IndexOp, Result, Tensor};
+use candle_core as candle;
 fn contiguous(device: &Device) -> Result<()> {
     let tensor = Tensor::arange(0u32, 24u32, device)?.reshape((2, 3, 4))?;
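This hunk takes a different route from the other test files: instead of rewriting every `candle::` path, it aliases the renamed crate. A minimal sketch of the same idiom, with a hypothetical `zeros_len` helper built from calls that appear in this diff:

```rust
// One alias at the top of the file; every existing `candle::` path
// below it keeps compiling against the renamed crate.
use candle_core as candle;

use candle::{DType, Device, Result, Tensor};

fn zeros_len() -> Result<usize> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    Ok(t.elem_count())
}

fn main() -> Result<()> {
    assert_eq!(zeros_len()?, 6);
    Ok(())
}
```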

View File

@@ -1,5 +1,5 @@
 mod test_utils;
-use candle::{DType, Device, IndexOp, Result, Tensor};
+use candle_core::{DType, Device, IndexOp, Result, Tensor};
 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;

View File

@@ -1,6 +1,6 @@
 #![allow(dead_code)]
-use candle::{Result, Tensor};
+use candle_core::{Result, Tensor};
 #[macro_export]
 macro_rules! test_device {

View File

@@ -11,7 +11,7 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 [dependencies]
-candle = { path = "../candle-core" }
+candle = { path = "../candle-core", package = "candle-core" }
 candle-nn = { path = "../candle-nn" }
 candle-transformers = { path = "../candle-transformers" }
 candle-flash-attn = { path = "../candle-flash-attn", optional = true }
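The `package = "candle-core"` key is Cargo's dependency-renaming feature: the dependency is still exposed under the name on the left (`candle`) while the crate actually resolved is `candle-core`, which is why the example sources need no import rewrite in this commit. A sketch of what consumer code can keep doing under that manifest entry:

```rust
// With `candle = { path = "../candle-core", package = "candle-core" }`
// in Cargo.toml, source code still refers to the crate as `candle`.
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
    println!("{a}");
    Ok(())
}
```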

View File

@@ -106,7 +106,7 @@ impl TensorParallelRowLinear {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
-        Ok(Self::new(Linear::new(weight, None), comm.clone()))
+        Ok(Self::new(Linear::new(weight, None), comm))
     }
 }
@@ -296,8 +296,8 @@ impl CausalSelfAttention {
         let k = k.transpose(1, 2)?;
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-        let y =
-            candle_flash_attn::flash_attn(q, k, v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
+        let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
+            .transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.o_proj.forward(&y)?;
@@ -363,7 +363,7 @@ impl Mlp {
     fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
-        let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
+        let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
         Ok(Self::new(c_fc1, c_fc2, c_proj))
     }
 }
@@ -396,7 +396,7 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
-        let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
+        let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
             RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
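The `comm.clone()` to `comm` changes above are a clippy-style cleanup: an `Rc` clone is only a reference-count bump, but the final use of the value can simply be moved. A standalone sketch of the pattern (the `Comm` type here is a hypothetical stand-in for the communicator in the diff):

```rust
use std::rc::Rc;

// Hypothetical stand-in for the communicator shared across layers.
struct Comm;

fn load_layer(_comm: Rc<Comm>) {}

fn main() {
    let comm = Rc::new(Comm);
    load_layer(comm.clone()); // earlier uses must clone (cheap refcount bump)
    load_layer(comm); // last use: move the Rc, no clone needed
}
```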

View File

@@ -11,7 +11,7 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"] }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
 half = { version = "2.3.1", features = ["num-traits"] }
 [build-dependencies]

View File

@@ -11,7 +11,7 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 [dependencies]
-candle = { path = "../candle-core" }
+candle = { path = "../candle-core", package = "candle-core" }
 thiserror = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 safetensors = { workspace = true }

View File

@@ -16,7 +16,7 @@ crate-type = ["cdylib"]
 doc = false
 [dependencies]
-candle = { path = "../candle-core" }
+candle = { path = "../candle-core", package = "candle-core" }
 pyo3 = { version = "0.19.0", features = ["extension-module"] }
 half = { workspace = true }

View File

@@ -11,7 +11,7 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 [dependencies]
-candle = { path = "../candle-core" }
+candle = { path = "../candle-core", package = "candle-core" }
 hf-hub = { workspace = true }
 candle-nn = { path = "../candle-nn" }
 intel-mkl-src = { workspace = true, optional = true }

View File

@@ -11,7 +11,7 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 [dependencies]
-candle = { path = "../../candle-core" }
+candle = { path = "../../candle-core", package = "candle-core" }
 candle-nn = { path = "../../candle-nn" }
 num-traits = { workspace = true }

View File

@@ -11,7 +11,7 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 [dependencies]
-candle = { path = "../../candle-core" }
+candle = { path = "../../candle-core", package = "candle-core" }
 candle-nn = { path = "../../candle-nn" }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }