Source code for fracdiff.torch.module

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module

from . import functional

class Fracdiff(Module):
    """A ``torch.nn.Module`` to compute fractional differentiation.

    Args:
        d (float): The order of differentiation.
        dim (int, default=-1): The dimension to differentiate.
            Currently, only the last dimension is supported.
        window (int, default=10): The window size for fractional
            differentiation.
        mode (str, default="same"): "same" or "valid".
            See :func:`fracdiff.fdiff` for details.

    Shape:
        - input: :math:`(N, *, L_{\\mathrm{in}})`, where :math:`*` means
          any number of additional dimensions.
        - output: :math:`(N, *, L_{\\mathrm{out}})`, where
          :math:`L_{\\mathrm{out}}` is given by :math:`L_{\\mathrm{in}}`
          if `mode="same"` and
          :math:`L_{\\mathrm{in}} - \\mathrm{window} + 1` if `mode="valid"`.

    Examples:
        >>> from fracdiff.torch import Fracdiff
        >>> m = Fracdiff(0.5)
        >>> m
        Fracdiff(0.5, dim=-1, window=10, mode='same')
        >>> input = torch.arange(10).reshape(2, 5)
        >>> m(input)
        tensor([[0.0000, 1.0000, 1.5000, 1.8750, 2.1875],
                [5.0000, 3.5000, 3.3750, 3.4375, 3.5547]])
    """

    def __init__(
        self, d: float, dim: int = -1, window: int = 10, mode: str = "same"
    ) -> None:
        super().__init__()
        # The hyperparameters are stored as plain attributes and are forwarded
        # unchanged to ``functional.fdiff`` on every call to ``forward``. No
        # validation happens here; invalid values are reported (if at all) by
        # ``functional.fdiff`` when the module is called.
        self.d = d
        self.dim = dim
        self.window = window
        self.mode = mode

    def extra_repr(self) -> str:
        """Return the argument summary shown inside ``repr(module)``.

        ``torch.nn.Module.__repr__`` wraps this string in ``Fracdiff(...)``.
        For example, ``Fracdiff(0.5)`` is shown as
        ``Fracdiff(0.5, dim=-1, window=10, mode='same')``.
        """
        params = (
            str(self.d),
            f"dim={self.dim}",
            f"window={self.window}",
            f"mode='{self.mode}'",
        )
        return ", ".join(params)

    def forward(
        self,
        input: Tensor,
        prepend: Optional[Tensor] = None,
        append: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply fractional differentiation.

        Args:
            input (torch.Tensor): The input tensor.
            prepend (torch.Tensor, optional): The tensor to prepend to
                `input` along `self.dim` before computing the
                differentiation. It must have the same number of
                dimensions as `input`, and its shape must match `input`'s
                shape except along `self.dim`.
            append (torch.Tensor, optional): The tensor to append to `input`
                along `self.dim` before computing the differentiation.
                The same shape requirements as for `prepend` apply.

        Returns:
            torch.Tensor: The fractionally differentiated tensor, computed
            by ``functional.fdiff`` with this module's `d`, `dim`, `window`
            and `mode`.
        """
        return functional.fdiff(
            input,
            self.d,
            dim=self.dim,
            window=self.window,
            mode=self.mode,
            prepend=prepend,
            append=append,
        )