Detach the tensors on batch-norm eval. (#1702)

* Detach the tensors on batch-norm eval.

* Fix pyo3 bindings.

* Black tweak.

* Formatting.

* Also update the pyo3-onnx formatting.

* Apply black.
Author: Laurent Mazare
Date: 2024-02-13 14:26:32 +01:00
Committed by: GitHub
Parent: 13c67226e6
Commit: ad73e93da2
14 changed files with 117 additions and 27 deletions
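In eval mode, batch normalization rescales its input with the stored running mean and variance. Those buffers are `Var`s so that training can update them, but gradients from a downstream loss should not flow back into them. The sketch below is not part of the diff; it is a minimal illustration written against the candle-core API as of this commit, showing that a buffer read through the new `Var::as_detached_tensor` stays out of the gradient store while the actual input still receives a gradient.

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A trainable input and a "running mean" buffer, standing in for the
    // statistics a BatchNorm layer consults in eval mode.
    let x = Var::new(&[1f32, 2., 3., 4.], &dev)?;
    let running_mean = Var::new(&[0.5f32, 0.5, 0.5, 0.5], &dev)?;

    // Reading the buffer through `as_detached_tensor` keeps it out of the
    // backprop graph: the subtraction only records a dependency on `x`.
    let y = x.as_tensor().sub(&running_mean.as_detached_tensor())?;
    let loss = y.sqr()?.sum_all()?;
    let grads = loss.backward()?;

    assert!(grads.get(x.as_tensor()).is_some()); // the input gets a gradient
    assert!(grads.get(running_mean.as_tensor()).is_none()); // the detached buffer does not
    Ok(())
}
```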


@@ -175,7 +175,7 @@ impl Tensor {
         // the backprop graph of the backprop itself. This would be an issue for second order
         // derivatives but these are out of scope at the moment.
         let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
-        let grad = if do_not_detach { grad } else { grad.detach()? };
+        let grad = if do_not_detach { grad } else { grad.detach() };
         if let Some(op) = node.op() {
             match op {
                 Op::Binary(lhs, rhs, BinaryOp::Add) => {


@@ -1882,9 +1882,9 @@ impl Tensor {
     /// this new node. The storage of this tensor is shared with the initial tensor.
     ///
     /// If the tensor is already detached from the computation graph, the same tensor is returned.
-    pub fn detach(&self) -> Result<Tensor> {
+    pub fn detach(&self) -> Tensor {
         if self.op.is_none() && !self.is_variable {
-            Ok(self.clone())
+            self.clone()
         } else {
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -1895,7 +1895,7 @@ impl Tensor {
                 dtype: self.dtype,
                 device: self.device.clone(),
             };
-            Ok(Tensor(Arc::new(tensor_)))
+            Tensor(Arc::new(tensor_))
         }
     }
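Since `detach` can no longer fail, the `Result` wrapper and the trailing `?` disappear at every call site touched below. A minimal sketch of a call site after the change (the function name is illustrative, not from the diff):

```rust
use candle_core::{Result, Tensor};

// `detach` is now infallible, so it chains directly into fallible ops
// without needing its own `?`.
fn frozen_batch(t: &Tensor) -> Result<Tensor> {
    t.detach().unsqueeze(0)
}
```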


@@ -107,6 +107,10 @@ impl Var {
         Ok(Self(inner))
     }

+    pub fn as_detached_tensor(&self) -> Tensor {
+        self.0.detach()
+    }
+
     pub fn as_tensor(&self) -> &Tensor {
         &self.0
     }


@@ -411,7 +411,7 @@ impl DDPG<'_> {
     pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
         let actions = self
             .actor
-            .forward(&state.detach()?.unsqueeze(0)?)?
+            .forward(&state.detach().unsqueeze(0)?)?
             .squeeze(0)?;
         let actions = if self.train {
             (actions + self.ou_noise.sample()?)?


@@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
     loop {
         let action = {
             let action_probs: Vec<f32> =
-                softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
+                softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
                     .squeeze(0)?
                     .to_vec1()?;
             weighted_sample(action_probs, &mut rng)? as i64
@@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
         let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
             .to_dtype(DType::F32)?
-            .detach()?;
+            .detach();

         let actions_mask = {
             let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
                 .unwrap()
             })
             .collect();
-            Tensor::stack(&actions_mask, 0)?.detach()?
+            Tensor::stack(&actions_mask, 0)?.detach()
         };
         let states = {
             let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
-            Tensor::stack(&states, 0)?.detach()?
+            Tensor::stack(&states, 0)?.detach()
         };
         let log_probs = actions_mask


@@ -262,9 +262,19 @@ impl BatchNorm {
         let target_shape = target_shape.as_slice();
         let x = x
-            .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+            .broadcast_sub(
+                &self
+                    .running_mean
+                    .as_detached_tensor()
+                    .reshape(target_shape)?,
+            )?
             .broadcast_div(
-                &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+                &(self
+                    .running_var
+                    .as_detached_tensor()
+                    .reshape(target_shape)?
+                    + self.eps)?
+                    .sqrt()?,
             )?;
         match &self.weight_and_bias {
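The eval path above computes `(x - running_mean) / sqrt(running_var + eps)` and, after this change, reads both buffers through `as_detached_tensor()`, so the normalization behaves as a constant transform with respect to backprop. A simplified, hedged sketch of that computation follows; it is a standalone helper with a fixed NCHW broadcast shape and stand-in buffer arguments, not the actual `BatchNorm` code:

```rust
use candle_core::{Result, Tensor, Var};

// Simplified eval-time batch-norm normalization over an NCHW input.
// `running_mean` / `running_var` stand in for the layer's buffers.
fn normalize_eval(x: &Tensor, running_mean: &Var, running_var: &Var, eps: f64) -> Result<Tensor> {
    let c = running_mean.as_tensor().dim(0)?;
    let target_shape = (1, c, 1, 1); // broadcast over batch and spatial dims
    x.broadcast_sub(&running_mean.as_detached_tensor().reshape(target_shape)?)?
        .broadcast_div(&(running_var.as_detached_tensor().reshape(target_shape)? + eps)?.sqrt()?)
}
```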


@@ -88,23 +88,27 @@ class QTensor:
         Dequantizes the tensor.
         """
         pass
     @property
     def ggml_dtype(self) -> str:
         """
         Gets the tensors quantized dtype.
         """
         pass
     def matmul_t(self, lhs: Tensor) -> Tensor:
         """
         Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
         """
         pass
     @property
     def rank(self) -> int:
         """
         Gets the rank of the tensor.
         """
         pass
     @property
     def shape(self) -> Tuple[int]:
         """
@@ -119,178 +123,213 @@ class Tensor:
     def __init__(self, data: _ArrayLike):
         pass
     def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Add a scalar to a tensor or two tensors together.
         """
         pass
     def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
     def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
     def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
         """
         Return a slice of a tensor.
         """
         pass
     def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
     def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
     def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
     def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Multiply a tensor by a scalar or one tensor by another.
         """
         pass
     def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
     def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Add a scalar to a tensor or two tensors together.
         """
         pass
     def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
     def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Multiply a tensor by a scalar or one tensor by another.
         """
         pass
     def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Subtract a scalar from a tensor or one tensor from another.
         """
         pass
     def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Divide a tensor by a scalar or one tensor by another.
         """
         pass
     def abs(self) -> Tensor:
         """
         Performs the `abs` operation on the tensor.
         """
         pass
     def argmax_keepdim(self, dim: int) -> Tensor:
         """
         Returns the indices of the maximum value(s) across the selected dimension.
         """
         pass
     def argmin_keepdim(self, dim: int) -> Tensor:
         """
         Returns the indices of the minimum value(s) across the selected dimension.
         """
         pass
     def broadcast_add(self, rhs: Tensor) -> Tensor:
         """
         Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
         """
         pass
     def broadcast_as(self, *shape: Shape) -> Tensor:
         """
         Broadcasts the tensor to the given shape.
         """
         pass
     def broadcast_div(self, rhs: Tensor) -> Tensor:
         """
         Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
         """
         pass
     def broadcast_left(self, *shape: Shape) -> Tensor:
         """
         Broadcasts the tensor to the given shape, adding new dimensions on the left.
         """
         pass
     def broadcast_mul(self, rhs: Tensor) -> Tensor:
         """
         Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
         """
         pass
     def broadcast_sub(self, rhs: Tensor) -> Tensor:
         """
         Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
         """
         pass
     def contiguous(self) -> Tensor:
         """
         Makes the tensor contiguous in memory.
         """
         pass
     def copy(self) -> Tensor:
         """
         Returns a copy of the tensor.
         """
         pass
     def cos(self) -> Tensor:
         """
         Performs the `cos` operation on the tensor.
         """
         pass
     def detach(self) -> Tensor:
         """
         Detach the tensor from the computation graph.
         """
         pass
     @property
     def device(self) -> Device:
         """
         Gets the tensor's device.
         """
         pass
     @property
     def dtype(self) -> DType:
         """
         Gets the tensor's dtype.
         """
         pass
     def exp(self) -> Tensor:
         """
         Performs the `exp` operation on the tensor.
         """
         pass
     def flatten_all(self) -> Tensor:
         """
         Flattens the tensor into a 1D tensor.
         """
         pass
     def flatten_from(self, dim: int) -> Tensor:
         """
         Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
         """
         pass
     def flatten_to(self, dim: int) -> Tensor:
         """
         Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
         """
         pass
     def get(self, index: int) -> Tensor:
         """
         Gets the value at the specified index.
         """
         pass
     def index_select(self, rhs: Tensor, dim: int) -> Tensor:
         """
         Select values for the input tensor at the target indexes across the specified dimension.
@@ -302,161 +341,192 @@ class Tensor:
         tensor.
         """
         pass
     def is_contiguous(self) -> bool:
         """
         Returns true if the tensor is contiguous in C order.
         """
         pass
     def is_fortran_contiguous(self) -> bool:
         """
         Returns true if the tensor is contiguous in Fortran order.
         """
         pass
     def log(self) -> Tensor:
         """
         Performs the `log` operation on the tensor.
         """
         pass
     def matmul(self, rhs: Tensor) -> Tensor:
         """
         Performs a matrix multiplication between the two tensors.
         """
         pass
     def max_keepdim(self, dim: int) -> Tensor:
         """
         Gathers the maximum value across the selected dimension.
         """
         pass
     def mean_all(self) -> Tensor:
         """
         Returns the mean of the tensor.
         """
         pass
     def min_keepdim(self, dim: int) -> Tensor:
         """
         Gathers the minimum value across the selected dimension.
         """
         pass
     def narrow(self, dim: int, start: int, len: int) -> Tensor:
         """
         Returns a new tensor that is a narrowed version of the input, the dimension `dim`
         ranges from `start` to `start + len`.
         """
         pass
     @property
     def nelement(self) -> int:
         """
         Gets the tensor's element count.
         """
         pass
     def powf(self, p: float) -> Tensor:
         """
         Performs the `pow` operation on the tensor with the given exponent.
         """
         pass
     def quantize(self, quantized_dtype: str) -> QTensor:
         """
         Quantize the tensor.
         """
         pass
     @property
     def rank(self) -> int:
         """
         Gets the tensor's rank.
         """
         pass
     def recip(self) -> Tensor:
         """
         Get the `recip` of the tensor.
         """
         pass
     def reshape(self, *shape: Shape) -> Tensor:
         """
         Reshapes the tensor to the given shape.
         """
         pass
     @property
     def shape(self) -> Tuple[int]:
         """
         Gets the tensor's shape.
         """
         pass
     def sin(self) -> Tensor:
         """
         Performs the `sin` operation on the tensor.
         """
         pass
     def sqr(self) -> Tensor:
         """
         Squares the tensor.
         """
         pass
     def sqrt(self) -> Tensor:
         """
         Calculates the square root of the tensor.
         """
         pass
     def squeeze(self, dim: int) -> Tensor:
         """
         Creates a new tensor with the specified dimension removed if its size was one.
         """
         pass
     @property
     def stride(self) -> Tuple[int]:
         """
         Gets the tensor's strides.
         """
         pass
     def sum_all(self) -> Tensor:
         """
         Returns the sum of the tensor.
         """
         pass
     def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor:
         """
         Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions.
         """
         pass
     def t(self) -> Tensor:
         """
         Transposes the tensor.
         """
         pass
     def to(self, *args, **kwargs) -> Tensor:
         """
         Performs Tensor dtype and/or device conversion.
         """
         pass
     def to_device(self, device: Union[str, Device]) -> Tensor:
         """
         Move the tensor to a new device.
         """
         pass
     def to_dtype(self, dtype: Union[str, DType]) -> Tensor:
         """
         Convert the tensor to a new dtype.
         """
         pass
     def to_torch(self) -> torch.Tensor:
         """
         Converts candle's tensor to pytorch's tensor
         """
         pass
     def transpose(self, dim1: int, dim2: int) -> Tensor:
         """
         Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
         """
         pass
     def unsqueeze(self, dim: int) -> Tensor:
         """
         Creates a new tensor with a dimension of size one inserted at the specified position.
         """
         pass
     def values(self) -> _ArrayLike:
         """
         Gets the tensor's data as a Python scalar or array-like object.
         """
         pass
     def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor:
         """
         Returns a tensor with the same shape as the input tensor, the values are taken from


@@ -57,12 +57,10 @@ class Sequential(Module):
     _modules: Dict[str, Module]  # type: ignore[assignment]
     @overload
-    def __init__(self, *args: Module) -> None:
-        ...
+    def __init__(self, *args: Module) -> None: ...
     @overload
-    def __init__(self, arg: "OrderedDict[str, Module]") -> None:
-        ...
+    def __init__(self, arg: "OrderedDict[str, Module]") -> None: ...
     def __init__(self, *args):
         super().__init__()


@@ -204,12 +204,10 @@ class Module:
     T_destination = TypeVar("T_destination", bound=Dict[str, Any])
     @overload
-    def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
-        ...
+    def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...
     @overload
-    def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
-        ...
+    def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ...
     def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
         r"""Returns a dictionary containing references to the whole state of the module.
@@ -586,12 +584,10 @@ class Module:
         self: T,
         device: str = ...,
         dtype: Optional[Union[DType, str]] = ...,
-    ) -> T:
-        ...
+    ) -> T: ...
     @overload
-    def to(self: T, dtype: Union[DType, str]) -> T:
-        ...
+    def to(self: T, dtype: Union[DType, str]) -> T: ...
     def to(self, *args, **kwargs):
         r"""Moves and/or casts the parameters and buffers.


@@ -14,6 +14,7 @@ class LayerNorm(Module):
     math::
         y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
     """
     __constants__ = ["normalized_shape", "eps"]
     normalized_shape: Tuple[int, ...]
     eps: float


@@ -11,59 +11,69 @@ class ONNXModel:
     def __init__(self, path: str):
         pass
     @property
     def doc_string(self) -> str:
         """
         The doc string of the model.
         """
         pass
     @property
     def domain(self) -> str:
         """
         The domain of the operator set of the model.
         """
         pass
     def initializers(self) -> Dict[str, Tensor]:
         """
         Get the weights of the model.
         """
         pass
     @property
     def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
         """
         The inputs of the model.
         """
         pass
     @property
     def ir_version(self) -> int:
         """
         The version of the IR this model targets.
         """
         pass
     @property
     def model_version(self) -> int:
         """
         The version of the model.
         """
         pass
     @property
     def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
         """
         The outputs of the model.
         """
         pass
     @property
     def producer_name(self) -> str:
         """
         The producer of the model.
         """
         pass
     @property
     def producer_version(self) -> str:
         """
         The version of the producer of the model.
         """
         pass
     def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """
         Run the model on the given inputs.
@@ -81,6 +91,7 @@ class ONNXTensorDescription:
         The data type of the tensor.
         """
         pass
     @property
     def shape(self) -> Tuple[Union[int, str, Any]]:
         """


@@ -938,8 +938,8 @@ impl PyTensor {
     /// Detach the tensor from the computation graph.
     /// &RETURNS&: Tensor
-    fn detach(&self) -> PyResult<Self> {
-        Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
+    fn detach(&self) -> Self {
+        PyTensor(self.0.detach())
     }

     /// Returns a copy of the tensor.


@@ -189,7 +189,6 @@ def do_black(content, is_pyi):
         line_length=119,
         is_pyi=is_pyi,
         string_normalization=True,
-        experimental_string_processing=False,
     )
     try:
         return black.format_file_contents(content, fast=True, mode=mode)


@@ -1,6 +1,7 @@
 ## Running Segment Anything Example

-Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
+Here, we provide an example showing how to run the Segment Anything model in the
+browser.

 ### Vanilla JS and WebWorkers