# Source code for dtw_loss_functions.soft_dtw

"""
Since several implementations of SoftDTW are available online, this module provides a wrapper that makes it easy to switch between them.

Currently, the following implementations are available :

- pytorch-softdtw-cuda by Mehran Maghoumi :cite:`soft_dtw_mag_1` :cite:`soft_dtw_mag_2`
    - GitHub repository : https://github.com/Maghoumi/pytorch-softdtw-cuda
- pysdtw by Antoine Loriette
    - GitHub repository : https://github.com/toinsson/pysdtw
    - PyPI page : https://pypi.org/project/pysdtw/
- sdtw-cuda-torch by BGU-CS-VIL (implemented by Ron Shapira Weber) :cite:`soft_dtw_ron_1` :cite:`soft_dtw_ron_2`
    - GitHub repository : https://github.com/BGU-CS-VIL/sdtw-cuda-torch

If you use this module, please cite, together with this package, the original paper of the implementation you are using.

Authors
-------
Alberto Zancanaro <alberto.zancanaro@uni.lu>

"""

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

import torch

import pysdtw
from .soft_dtw_implementations import soft_dtw_cuda_mag, pysdtw_normalize, soft_dtw_cuda_ron

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

class soft_dtw(torch.nn.Module) :
    """
    SoftDTW class.

    This class is a wrapper around different implementations of the SoftDTW.
    The implementation is selected through the ``implementation`` argument of the constructor.
    The available implementations are:

    - ``mag``: pytorch-softdtw-cuda by Mehran Maghoumi
    - ``pysdtw``: pysdtw by Antoine Loriette
    - ``ron``: sdtw-cuda-torch by BGU-CS-VIL (implemented by Ron Shapira Weber)

    Parameters
    ----------
    use_cuda : bool
        If true, this class will use the CUDA implementation of the SDTW.
        Note that this flag is not forwarded to the ``ron`` implementation.
    gamma : float, optional
        Value of the gamma hyperparameter for the SDTW. Default is ``1``.
    normalize : bool, optional
        If true, the SDTW divergence will be computed instead of the SDTW. Default is ``False``.
    bandwidth : float, optional
        Sakoe-Chiba bandwidth for pruning. If ``None`` is given, no pruning is applied. Default is ``None``.
    dist_func : function, optional
        Distance function to use for the SDTW. Default is ``None``, which corresponds to the squared Euclidean distance.
    implementation : str, optional
        Implementation to use for the SDTW.

        - ``mag``. Use the implementation by Mehran Maghoumi :cite:`soft_dtw_mag_1` :cite:`soft_dtw_mag_2`.
        - ``pysdtw``. Use the implementation in the pysdtw package by Antoine Loriette.
        - ``ron``. Use the implementation by Ron Shapira Weber :cite:`soft_dtw_ron_1` :cite:`soft_dtw_ron_2`.

        Default is ``mag``.
    fused : bool, optional
        Only used by the ``ron`` implementation; ignored by the others.

        - ``None`` -> auto (use fused only when possible)
        - ``True`` -> require fused (error if not possible)
        - ``False`` -> never fused (always materialize D and use D-based autograd)

    Raises
    ------
    ValueError
        If ``implementation`` is not one of ``mag``, ``pysdtw`` or ``ron``.
    """

    def __init__(self, use_cuda : bool, gamma : float = 1, normalize : bool = False,
                 bandwidth : int = None, dist_func : callable = None,
                 implementation : str = 'mag', fused : bool = None) :
        super().__init__()

        # Validate up front. Without this, an unknown name falls through every branch below,
        # ``sdtw_function`` is never assigned, and the error only surfaces later in ``forward``.
        self.check_implementation(implementation)

        if implementation == 'mag' :
            self.sdtw_function = soft_dtw_cuda_mag.SoftDTW(use_cuda = use_cuda, gamma = gamma, normalize = normalize,
                                                           bandwidth = bandwidth, dist_func = dist_func)
        elif implementation == 'pysdtw' :
            # ``normalize`` is not passed to pysdtw.SoftDTW; the divergence goes through
            # the package's own ``pysdtw_normalized`` wrapper instead.
            if normalize :
                self.sdtw_function = pysdtw_normalize.pysdtw_normalized(use_cuda = use_cuda, gamma = gamma,
                                                                        bandwidth = bandwidth, dist_func = dist_func)
            else :
                self.sdtw_function = pysdtw.SoftDTW(use_cuda = use_cuda, gamma = gamma,
                                                    bandwidth = bandwidth, dist_func = dist_func)
        elif implementation == 'ron' :
            # NOTE: ``use_cuda`` is not forwarded here; ``fused`` is only consumed by this backend.
            self.sdtw_function = soft_dtw_cuda_ron.SoftDTW(gamma = gamma, normalize = normalize, bandwidth = bandwidth,
                                                           dist_func = dist_func, fused = fused)

    def forward(self, x : torch.Tensor, y : torch.Tensor) -> torch.Tensor :
        """
        Compute the SoftDTW distance between two time series.

        The call is delegated unchanged to the backend selected in the constructor.

        Parameters
        ----------
        x : torch.Tensor
            First input tensor of shape B x T x C
        y : torch.Tensor
            Second input tensor of shape B x T x C

        Returns
        -------
        torch.Tensor
            SoftDTW distance (or SoftDTW divergence if ``normalize`` was set) between the two time series
        """
        return self.sdtw_function(x, y)

    def check_implementation(self, implementation : str) :
        """
        Check if the selected implementation is valid. If not, raise an error.

        Parameters
        ----------
        implementation : str
            Name of the implementation to validate.

        Raises
        ------
        ValueError
            If ``implementation`` is not one of ``mag``, ``pysdtw`` or ``ron``.
        """
        implementations = ['mag', 'pysdtw', 'ron']
        if implementation not in implementations :
            raise ValueError(f"Invalid implementation selected. Implementations available: {implementations}. Selected implementation: {implementation}.")