Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Abstract the gradient storage.
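The change set below replaces the bare HashMap<TensorId, Tensor> that backward() used to return with a dedicated GradStore wrapper, makes the Device and Tensor constructors fallible (returning Result and threading `?` through their call sites), and declares an optional cudarc dependency behind a new `cuda` Cargo feature.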
@@ -13,9 +13,14 @@ readme = "README.md"
 [dependencies]
 safetensors = "0.3.1"
 thiserror = "1"
+cudarc = { version = "0.9.9", optional = true }
 
 [dev-dependencies]
 anyhow = "1"
 clap = { version = "4.2.4", features = ["derive"] }
 rand = "0.8.5"
 tokenizers = "0.13.3"
+
+[features]
+default = []
+cuda = ["dep:cudarc"]
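This Cargo.toml change makes cudarc an optional dependency gated behind a new `cuda` feature (enabled with e.g. `cargo build --features cuda`). As a minimal standalone sketch of how such a feature gate is typically consumed, the snippet below uses a made-up `cuda_available` helper; it is not part of the diff or of candle's API.

// Compiled only when building with `--features cuda`.
#[cfg(feature = "cuda")]
fn cuda_available() -> bool {
    true
}

// Fallback used when the feature is off, so the rest of the code still builds.
#[cfg(not(feature = "cuda"))]
fn cuda_available() -> bool {
    false
}

fn main() {
    println!("cuda support compiled in: {}", cuda_available());
}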
@@ -54,27 +54,36 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
 }
 
 impl Device {
-    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Storage {
+    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
-            Device::Cpu => Storage::Cpu(CpuStorage::ones_impl(shape, dtype)),
+            Device::Cpu => {
+                let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
+                Ok(storage)
+            }
             Device::Cuda { gpu_id: _ } => {
                 todo!()
             }
         }
     }
 
-    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
+    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
-            Device::Cpu => Storage::Cpu(CpuStorage::zeros_impl(shape, dtype)),
+            Device::Cpu => {
+                let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype));
+                Ok(storage)
+            }
             Device::Cuda { gpu_id: _ } => {
                 todo!()
             }
         }
     }
 
-    pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
+    pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
-            Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
+            Device::Cpu => {
+                let storage = Storage::Cpu(array.to_cpu_storage());
+                Ok(storage)
+            }
             Device::Cuda { gpu_id: _ } => {
                 todo!()
             }
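These Device constructors now return Result<Storage>: the CPU arm simply wraps its buffer in Ok, while the CUDA arm is still todo!(). The self-contained sketch below mirrors that fallible-constructor shape with made-up stand-in types (not candle's), and returns an error for the unimplemented backend instead of the diff's todo!().

#[derive(Debug)]
enum MyError {
    Unsupported(&'static str),
}

#[allow(dead_code)]
enum MyDevice {
    Cpu,
    Cuda { gpu_id: usize },
}

enum MyStorage {
    Cpu(Vec<f32>),
}

impl MyDevice {
    // Same overall shape as Device::zeros above: the CPU arm wraps the buffer
    // in Ok, the GPU arm reports an error rather than panicking.
    fn zeros(&self, n: usize) -> Result<MyStorage, MyError> {
        match self {
            MyDevice::Cpu => Ok(MyStorage::Cpu(vec![0f32; n])),
            MyDevice::Cuda { gpu_id: _ } => Err(MyError::Unsupported("cuda not implemented")),
        }
    }
}

fn main() -> Result<(), MyError> {
    let storage = MyDevice::Cpu.zeros(8)?;
    match storage {
        MyStorage::Cpu(buf) => assert_eq!(buf.len(), 8),
    }
    Ok(())
}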
@@ -86,9 +86,9 @@ impl Tensor {
         dtype: DType,
         device: Device,
         is_variable: bool,
-    ) -> Self {
+    ) -> Result<Self> {
         let shape = shape.into();
-        let storage = device.ones(&shape, dtype);
+        let storage = device.ones(&shape, dtype)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -98,18 +98,18 @@ impl Tensor {
             op: None,
             is_variable,
         };
-        Self(Arc::new(tensor_))
+        Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, true)
     }
 
-    pub fn ones_like(&self) -> Self {
+    pub fn ones_like(&self) -> Result<Self> {
         Tensor::ones(self.shape(), self.dtype(), self.device())
     }
 
@@ -118,9 +118,9 @@ impl Tensor {
         dtype: DType,
         device: Device,
         is_variable: bool,
-    ) -> Self {
+    ) -> Result<Self> {
         let shape = shape.into();
-        let storage = device.zeros(&shape, dtype);
+        let storage = device.zeros(&shape, dtype)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -130,18 +130,18 @@ impl Tensor {
             op: None,
             is_variable,
         };
-        Self(Arc::new(tensor_))
+        Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, true)
     }
 
-    pub fn zeros_like(&self) -> Self {
+    pub fn zeros_like(&self) -> Result<Self> {
         Tensor::zeros(self.shape(), self.dtype(), self.device())
     }
 
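Caller-side, the whole constructor family is now fallible, so construction sites gain a `?`. A small usage sketch against the API exactly as it appears in this commit (later candle versions may differ):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // ones/zeros now return Result<Tensor>, so errors propagate with `?`.
    let a = Tensor::ones((2, 3), DType::F32, Device::Cpu)?;
    // The `_like` variants are fallible as well.
    let b = a.zeros_like()?;
    let (d1, d2) = b.shape().r2()?;
    assert_eq!((d1, d2), (2, 3));
    Ok(())
}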
@@ -151,7 +151,7 @@ impl Tensor {
         is_variable: bool,
     ) -> Result<Self> {
         let shape = array.shape()?;
-        let storage = device.tensor(array);
+        let storage = device.tensor(array)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -376,16 +376,16 @@ impl Tensor {
         nodes
     }
 
-    pub fn backward(&self) -> Result<HashMap<TensorId, Tensor>> {
+    pub fn backward(&self) -> Result<GradStore> {
         let sorted_nodes = self.sorted_nodes();
         println!("{}", sorted_nodes.len());
-        let mut grads = HashMap::new();
-        grads.insert(self.id, self.ones_like());
+        let mut grads = GradStore::new();
+        grads.insert(self, self.ones_like()?);
         for node in sorted_nodes.iter() {
             if node.is_variable {
                 continue;
             }
-            let grad = grads.remove(&node.id).unwrap();
+            let grad = grads.remove(node).unwrap();
             // TODO: We should perform all these operations in place (or at least not track the
             // whole graph).
             // The only drawback would be if we wanted to support grad of grad but this is out of
@@ -393,51 +393,51 @@ impl Tensor {
             if let Some(op) = &node.op {
                 match op {
                     Op::Add(lhs, rhs) => {
-                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
-                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                     }
                     Op::Sub(lhs, rhs) => {
-                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
-                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
                     }
                     Op::Mul(lhs, rhs) => {
                         let lhs_grad = grad.mul(rhs)?;
-                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                         let rhs_grad = grad.mul(lhs)?;
-                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
                     Op::Div(lhs, rhs) => {
                         let lhs_grad = grad.div(rhs)?;
-                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                         let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
-                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
                     Op::Affine { arg, mul, .. } => {
                         let arg_grad = grad.affine(*mul, 0.)?;
-                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Neg(arg) => {
                         let arg_grad = grad.neg()?;
-                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Sqr(arg) => {
                         let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
-                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Sqrt(arg) => {
                         let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
-                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                 };
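Each backward rule accumulates into the gradient slot of its inputs, creating a zero gradient on first touch; a node that feeds several ops therefore receives several contributions. The standalone sketch below shows that accumulate-into-entry pattern with simplified stand-in types (a usize TensorId and f32 gradients), not candle's.

use std::collections::HashMap;

type TensorId = usize;

// Look up the gradient slot for `id`, creating a zero entry if absent, then
// add the incoming contribution — the same role as the or_insert/add pairs above.
fn accumulate(grads: &mut HashMap<TensorId, f32>, id: TensorId, contribution: f32) {
    let slot = grads.entry(id).or_insert(0.0);
    *slot += contribution;
}

fn main() {
    let mut grads = HashMap::new();
    // A node used twice (e.g. x in x * x) gets two contributions added together.
    accumulate(&mut grads, 7, 1.5);
    accumulate(&mut grads, 7, 2.5);
    assert_eq!(grads[&7], 4.0);
}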
@@ -503,3 +503,39 @@ bin_trait!(Add, add, |_| 1., |v| v);
 bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
 bin_trait!(Mul, mul, |v| v, |_| 0.);
 bin_trait!(Div, div, |v| 1. / v, |_| 0.);
+
+pub struct GradStore(HashMap<TensorId, Tensor>);
+
+impl GradStore {
+    fn new() -> Self {
+        GradStore(HashMap::new())
+    }
+
+    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
+        self.0.get(&id)
+    }
+
+    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
+        self.0.get(&tensor.id)
+    }
+
+    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
+        self.0.remove(&tensor.id)
+    }
+
+    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
+        self.0.insert(tensor.id, grad)
+    }
+
+    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
+        use std::collections::hash_map::Entry;
+        let grad = match self.0.entry(tensor.id) {
+            Entry::Occupied(entry) => entry.into_mut(),
+            Entry::Vacant(entry) => {
+                let grad = tensor.zeros_like()?;
+                entry.insert(grad)
+            }
+        };
+        Ok(grad)
+    }
+}
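GradStore wraps the HashMap<TensorId, Tensor> that backward() previously returned directly, letting call sites key by tensor instead of by id. Its or_insert spells out the Entry match because the zero-gradient default is itself fallible (zeros_like now returns Result), which rules out entry().or_insert_with(), whose closure cannot use `?`. A standalone illustration of that constraint with simplified types (not candle's):

use std::collections::hash_map::Entry;
use std::collections::HashMap;

// Like GradStore::or_insert: return a mutable slot, creating the default lazily,
// but allow the default computation to fail — hence the explicit Entry match.
fn or_insert_fallible<'a>(
    map: &'a mut HashMap<u64, Vec<f32>>,
    key: u64,
    make_default: impl FnOnce() -> Result<Vec<f32>, String>,
) -> Result<&'a mut Vec<f32>, String> {
    let slot = match map.entry(key) {
        Entry::Occupied(entry) => entry.into_mut(),
        Entry::Vacant(entry) => entry.insert(make_default()?),
    };
    Ok(slot)
}

fn main() -> Result<(), String> {
    let mut grads = HashMap::new();
    let g = or_insert_fallible(&mut grads, 1, || Ok(vec![0.0; 3]))?;
    g[0] += 1.0;
    assert_eq!(grads[&1], vec![1.0, 0.0, 0.0]);
    Ok(())
}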
@@ -6,7 +6,7 @@ fn simple_grad() -> Result<()> {
     let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?;
     let y = (((&x * &x)? + &x * 5f64)? + 4f64)?;
     let grads = y.backward()?;
-    let grad_x = grads.get(&x.id()).context("no grad for x")?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
     assert_eq!(x.to_vec1::<f32>()?, [3., 1., 4.]);
     // y = x^2 + 5.x + 4
     assert_eq!(y.to_vec1::<f32>()?, [28., 10., 40.]);
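For reference, y = x^2 + 5x + 4 gives dy/dx = 2x + 5, so at x = [3, 1, 4] the expected gradient is [11, 7, 13]; the assertions visible in this hunk only check y itself ([28, 10, 40]).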
@@ -2,7 +2,7 @@ use candle::{DType, Device, Result, Tensor};
 
 #[test]
 fn zeros() -> Result<()> {
-    let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu);
+    let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu)?;
     let (dim1, dim2) = tensor.shape().r2()?;
     assert_eq!(dim1, 5);
     assert_eq!(dim2, 2);