Soft DTW

Since several implementations of SoftDTW are available online, this module provides a wrapper that makes it easy to switch between them.

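Currently, the following implementations are available:

  • mag: pytorch-softdtw-cuda by Mehran Maghoumi

  • pysdtw: pysdtw by Antoine Loriette

  • ron: sdtw-cuda-torch by BGU-CS-VIL (implemented by Ron Shapira Weber)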

If you use this module, please cite both this package and the original paper of the implementation you are using.

Authors

Alberto Zancanaro <alberto.zancanaro@uni.lu>

class dtw_loss_functions.soft_dtw.soft_dtw(use_cuda: bool, gamma: float = 1, normalize: bool = False, bandwidth: int = None, dist_func: callable = None, implementation: str = 'mag', fused: bool = None)

Bases: Module

SoftDTW class: a wrapper around the different SoftDTW implementations. The implementation is selected through the ‘implementation’ argument of the constructor. The available implementations are:

  • mag: pytorch-softdtw-cuda by Mehran Maghoumi

  • pysdtw: pysdtw by Antoine Loriette

  • ron: sdtw-cuda-torch by BGU-CS-VIL (implemented by Ron Shapira Weber)

Parameters:
  • use_cuda (bool) – If True, this class uses the CUDA implementation of the SDTW.

  • gamma (float, optional) – Value of the gamma (smoothing) hyperparameter for the SDTW. Smaller values bring the SDTW closer to the standard DTW. Default is 1.

  • normalize (bool, optional) – If True, the SDTW divergence is computed instead of the SDTW. Default is False.

  • bandwidth (float, optional) – Sakoe-Chiba bandwidth for pruning. If None is given, no pruning is applied. Default is None.

  • dist_func (function, optional) – Distance function to use for the SDTW. Default is None, which corresponds to the squared Euclidean distance.

  • implementation (str, optional) –

    Implementation to use for the SDTW.

    • mag: Use the implementation by Mehran Maghoumi [Mag20] [MTL21].

    • pysdtw: Use the implementation in the pysdtw package by Antoine Loriette.

    • ron: Use the implementation by Ron Shapira Weber [WF26] [SWBL+25].

    Default is mag.

  • fused (bool, optional) –

    Only for the ‘ron’ implementation.

    • None -> auto (use fused only when possible)

    • True -> require fused (error if not possible)

    • False -> never fused (always materialize D and use D-based autograd)
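For reference, the parameters gamma, dist_func, bandwidth and normalize act on the standard SoftDTW recursion, summarized below for two series x = (x_1, …, x_T) and y = (y_1, …, y_T'). This is the textbook formulation, not code taken from any of the wrapped implementations, which may differ in details such as how the band boundary is handled. As gamma tends to 0, the soft minimum tends to the ordinary minimum and the recursion reduces to classic DTW.

```latex
% Soft minimum with smoothing parameter \gamma > 0
\min\nolimits^{\gamma}(a_1,\dots,a_n) = -\gamma \log \sum_{k=1}^{n} e^{-a_k/\gamma}

% Dynamic-programming recursion
R_{0,0} = 0, \qquad R_{i,0} = R_{0,j} = +\infty \quad (i, j \ge 1)
R_{i,j} = \delta(x_i, y_j) + \min\nolimits^{\gamma}\left(R_{i-1,j-1},\; R_{i-1,j},\; R_{i,j-1}\right)
\mathrm{SDTW}_{\gamma}(x, y) = R_{T,T'}

% Default dist_func (None): squared Euclidean distance
\delta(x_i, y_j) = \lVert x_i - y_j \rVert_2^2

% Sakoe-Chiba band of width b: cells with |i - j| > b are skipped, i.e. R_{i,j} = +\infty

% normalize = True: SDTW divergence
D_{\gamma}(x, y) = \mathrm{SDTW}_{\gamma}(x, y) - \tfrac{1}{2}\left(\mathrm{SDTW}_{\gamma}(x, x) + \mathrm{SDTW}_{\gamma}(y, y)\right)
```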

Methods

check_implementation(implementation)

Check if the selected implementation is valid.

forward(x, y)

Compute the SoftDTW distance between two time series.

check_implementation(implementation: str)

Check if the selected implementation is valid. If not, raise an error.
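The check can be sketched as a membership test against the three names accepted by the ‘implementation’ argument. This is a hypothetical stand-alone function, not the package's method; the exception type (ValueError) and the message are assumptions, and the real check_implementation may raise a different error.

```python
# Hypothetical sketch: names taken from the documented 'implementation' values.
VALID_IMPLEMENTATIONS = ("mag", "pysdtw", "ron")


def check_implementation_sketch(implementation: str) -> None:
    """Raise an error if 'implementation' is not one of the supported backends."""
    if implementation not in VALID_IMPLEMENTATIONS:
        # ValueError is an assumption; the actual class may use another exception type.
        raise ValueError(
            f"Unknown implementation {implementation!r}; "
            f"expected one of: {', '.join(VALID_IMPLEMENTATIONS)}"
        )


check_implementation_sketch("mag")  # valid: returns None silently
```

Validating the name once in the constructor means an invalid choice fails immediately, rather than on the first call to forward.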

forward(x: Tensor, y: Tensor) → Tensor

Compute the SoftDTW distance between two time series.

Parameters:
  • x (torch.Tensor) – First input tensor of shape B x T x C (batch size, sequence length, number of channels)

  • y (torch.Tensor) – Second input tensor of shape B x T x C

Returns:

SoftDTW distance between the two time series (the SDTW divergence if normalize is True)

Return type:

torch.Tensor
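To make the computation behind forward more concrete, here is a minimal pure-Python sketch of the standard SoftDTW recursion, applied to a batch of B pairs of series stored as nested lists (T x C per series). It is not the code of any wrapped implementation, which all run vectorized in PyTorch. The names soft_dtw_pair and soft_dtw_forward are hypothetical. Here dist_func is assumed to be a per-point callable, whereas the real implementations may expect a function that returns a whole batched distance matrix. The sketch also treats bandwidth=0 as "diagonal only", which may not match every backend.

```python
import math
from typing import Callable, List, Optional, Sequence

Point = Sequence[float]   # one time step with C channels
Series = Sequence[Point]  # one time series, T x C


def sq_euclidean(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two C-dimensional points (default cost)."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def softmin(values: Sequence[float], gamma: float) -> float:
    """Smoothed minimum -gamma * log(sum(exp(-v / gamma))), shifted by the min for stability."""
    finite = [v for v in values if v != math.inf]
    if not finite:
        return math.inf  # every predecessor cell is unreachable
    m = min(finite)
    return m - gamma * math.log(sum(math.exp(-(v - m) / gamma) for v in finite))


def soft_dtw_pair(x: Series, y: Series, gamma: float = 1.0,
                  bandwidth: Optional[float] = None,
                  dist_func: Optional[Callable[[Point, Point], float]] = None) -> float:
    """SoftDTW between two series via the dynamic-programming table R."""
    dist = dist_func or sq_euclidean  # None -> squared Euclidean distance
    n, m = len(x), len(y)
    R = [[math.inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if bandwidth is not None and abs(i - j) > bandwidth:
                continue  # outside the Sakoe-Chiba band: the cell stays at +inf
            R[i][j] = dist(x[i - 1], y[j - 1]) + softmin(
                (R[i - 1][j - 1], R[i - 1][j], R[i][j - 1]), gamma)
    return R[n][m]


def soft_dtw_forward(X: Sequence[Series], Y: Sequence[Series], gamma: float = 1.0,
                     normalize: bool = False, bandwidth: Optional[float] = None,
                     dist_func: Optional[Callable[[Point, Point], float]] = None) -> List[float]:
    """One SoftDTW value (or divergence, if normalize=True) per pair in the batch."""
    def sdtw(a: Series, b: Series) -> float:
        return soft_dtw_pair(a, b, gamma, bandwidth, dist_func)

    out = []
    for x, y in zip(X, Y):
        value = sdtw(x, y)
        if normalize:
            # SDTW divergence: subtract the mean of the two self-comparison terms
            value -= 0.5 * (sdtw(x, x) + sdtw(y, y))
        out.append(value)
    return out


if __name__ == "__main__":
    X = [[[0.0], [1.0], [2.0]]]  # B=1, T=3, C=1
    Y = [[[1.0], [2.0], [3.0]]]
    print(soft_dtw_forward(X, Y, gamma=0.1))
    print(soft_dtw_forward(X, Y, gamma=0.1, normalize=True))
```

Note that the plain SDTW of a series with itself is negative (the soft minimum is always at most the true minimum), which is why normalize=True is useful when a loss that is zero for identical inputs is needed.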