Metadata-Version: 2.4
Name: fswlib
Version: 0.9.7
Summary: A PyTorch library for computing the Fourier Sliced-Wasserstein (FSW) embedding
Author-email: Tal Amir <talamir@technion.ac.il>, Nadav Dym <nadavdym@technion.ac.il>
Maintainer-email: Tal Amir <talamir@technion.ac.il>
License: BSD 3-Clause License
        
        Copyright (c) 2024-2025, Tal Amir
        
        Redistribution and use in source and binary forms, with or without
        modification, are permitted provided that the following conditions are met:
        
        1. Redistributions of source code must retain the above copyright notice, this
           list of conditions and the following disclaimer.
        
        2. Redistributions in binary form must reproduce the above copyright notice,
           this list of conditions and the following disclaimer in the documentation
           and/or other materials provided with the distribution.
        
        3. Neither the name of the copyright holder nor the names of its
           contributors may be used to endorse or promote products derived from
           this software without specific prior written permission.
        
        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
        AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
        IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
        DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
        FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
        DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
        SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
        CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
        OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
        OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Project-URL: Homepage, https://github.com/tal-amir/fswlib
Project-URL: Repository, https://github.com/tal-amir/fswlib
Project-URL: Documentation, https://tal-amir.github.io/fswlib
Keywords: sliced Wasserstein,Fourier Sliced Wasserstein,FSW embedding,multiset embedding
Classifier: Development Status :: 4 - Beta
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.10.3
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.24.4
Requires-Dist: torch>=2.1.0
Dynamic: license-file

# fswlib: A PyTorch Library for the Fourier Sliced-Wasserstein (FSW) Embedding

This package provides an implementation of the **Fourier Sliced-Wasserstein (FSW) embedding**, introduced in our [ICLR 2025 paper](https://iclr.cc/virtual/2025/poster/30562):

> **Fourier Sliced-Wasserstein Embedding for Multisets and Measures**  
> Tal Amir & Nadav Dym  
> *International Conference on Learning Representations (ICLR)*, 2025

---

## 📦 Requirements

- **Python** ≥ 3.10.3 (released March 2022)  
- **PyTorch** ≥ 2.1.0 (released October 2023)  
- **NumPy** ≥ 1.24.4 (released June 2023)  
  
The core package has been tested on **Linux** and **Windows**.  
It may also run on **macOS**, though this has not been verified.  
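
To check an existing environment against these minimums before installing, a short script along the following lines may help. It is a hypothetical sketch that is not part of fswlib: `release_tuple` and `check_requirements` are illustrative names. It reads installed versions with the standard-library `importlib.metadata` and ignores local tags and pre-release suffixes, so `2.1.0rc1` counts as `2.1.0`. For exact PEP 440 comparisons, `packaging.version.Version` is the more robust choice.

```python
import platform
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the Requirements list above
MIN_PYTHON = (3, 10, 3)
MINIMUM = {"numpy": (1, 24, 4), "torch": (2, 1, 0)}


def release_tuple(v):
    """Parse the leading numeric release of a version string.

    Local tags and pre-release suffixes are ignored, e.g. '2.1.0+cu118' -> (2, 1, 0);
    missing components are padded with zeros, e.g. '2.1' -> (2, 1, 0).
    """
    m = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(g) if g is not None else 0 for g in m.groups())


def check_requirements():
    """Return a list of human-readable problems (empty if all minimums are met)."""
    problems = []
    if sys.version_info[:3] < MIN_PYTHON:
        problems.append(f"Python {platform.python_version()} is older than 3.10.3")
    for pkg, minimum in MINIMUM.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            problems.append(f"{pkg} is not installed")
            continue
        if release_tuple(installed) < minimum:
            needed = ".".join(map(str, minimum))
            problems.append(f"{pkg} {installed} is older than {needed}")
    return problems


if __name__ == "__main__":
    issues = check_requirements()
    print("All requirements met." if not issues else "\n".join(issues))
```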


---

## 🔧 Installation

To install the package:

```bash
pip install fswlib
```
The core package runs on both **CPU** and **CUDA-enabled GPUs**, using PyTorch's standard CUDA backend.  

In addition, it includes an optional **custom CUDA extension** that can provide up to a 2× speedup for sparse weight matrices (e.g., sparse graphs). This extension is currently supported only on **Linux**.

  
To compile the optional extension, run:

```bash
fswlib-build
```
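
Because the extension is optional and Linux-only, a setup script might decide whether to run `fswlib-build` roughly as follows. This is a hypothetical sketch: the helper name `should_build_cuda_extension` is illustrative, and the requirement that a CUDA GPU be visible to PyTorch is our assumption, inferred from its being a CUDA extension rather than stated by the library. The script only checks that `fswlib-build` is on `PATH`; it does not run it.

```python
import platform
import shutil


def should_build_cuda_extension(system=None, cuda_available=None):
    """Return True if compiling the optional custom CUDA extension makes sense.

    Assumption: the extension needs Linux and a CUDA-enabled GPU visible to PyTorch.
    Both inputs can be passed explicitly; otherwise they are detected.
    """
    if system is None:
        system = platform.system()
    if cuda_available is None:
        try:
            import torch
            cuda_available = torch.cuda.is_available()
        except ImportError:
            cuda_available = False
    return system == "Linux" and bool(cuda_available)


if __name__ == "__main__":
    if not should_build_cuda_extension():
        print("Skipping the optional extension; the core package works without it.")
    elif shutil.which("fswlib-build") is None:
        print("fswlib-build was not found on PATH; is fswlib installed in this environment?")
    else:
        print("Run `fswlib-build` to compile the optional extension.")
```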


---

## 📘 Usage Example

Below is a basic usage example of the `FSWEmbedding` class.  

For more examples, see the `examples/` [directory](https://github.com/tal-amir/fswlib/tree/main/examples) of the GitHub repository.  
Full API documentation is available at [https://tal-amir.github.io/fswlib](https://tal-amir.github.io/fswlib).


```python
import torch
from fswlib import FSWEmbedding

# Configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32
d = 15     # Dimension of multiset elements
n = 50     # Multiset size
m = 123    # Embedding output dimension

# Create FSW embedding module for multisets/measures over ℝ^d
embed = FSWEmbedding(d_in=d, d_out=m, device=device, dtype=dtype)

# --- Single input multiset ---
X = torch.randn(size=(n, d), device=device, dtype=dtype)
W = torch.rand(n, device=device, dtype=dtype)  # Optional weights

X_emb = embed(X, W)  # Embeds a weighted multiset
X_emb = embed(X)     # Embeds X assuming uniform weights

# --- A batch of input multisets ---
batch_dims = (5, 3, 7, 9)
Xb = torch.randn(size=batch_dims + (n, d), device=device, dtype=dtype)
Wb = torch.rand(batch_dims + (n,), device=device, dtype=dtype)
Xb_emb = embed(Xb, Wb)

print(f"Dimension of multiset elements: {d}")
print(f"Embedding dimension: {m}")
print(f"\nOne multiset X of size {n}:")
print("X shape:", X.shape)
print("embed(X) shape:", X_emb.shape)

batch_dim_str = "×".join(str(b) for b in batch_dims)
print(f"\nBatch of {batch_dim_str} multisets, each of size {n}:")
print("Xb shape:", Xb.shape)
print("embed(Xb) shape:", Xb_emb.shape)
```

Output:
```
Dimension of multiset elements: 15
Embedding dimension: 123

One multiset X of size 50:
X shape: torch.Size([50, 15])
embed(X) shape: torch.Size([123])

Batch of 5×3×7×9 multisets, each of size 50:
Xb shape: torch.Size([5, 3, 7, 9, 50, 15])
embed(Xb) shape: torch.Size([5, 3, 7, 9, 123])
```
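
The printed shapes follow a simple rule: `X` has shape `(*batch, n, d)`, `W` (when given) has shape `(*batch, n)`, and the embedding has shape `(*batch, m)`, so the last two axes of `X` are replaced by a single axis of size `m`. The plain-Python helper below restates that rule, which can be handy for sanity-checking tensor shapes before calling the embedding. It is a hypothetical sketch: `fsw_output_shape` is not part of fswlib, and the rule is inferred from the example above rather than from the library's own input validation, which may, for instance, accept other weight shapes.

```python
def fsw_output_shape(x_shape, d_in, d_out, w_shape=None):
    """Expected embedding shape for X of shape (*batch, n, d_in) and optional W of shape (*batch, n)."""
    if len(x_shape) < 2:
        raise ValueError("X must have shape (..., n, d_in)")
    *batch, n, d = x_shape
    if d != d_in:
        raise ValueError(f"last axis of X is {d}, expected d_in={d_in}")
    if w_shape is not None and tuple(w_shape) != (*batch, n):
        raise ValueError(f"W must have shape {(*batch, n)}, got {tuple(w_shape)}")
    # The trailing (n, d_in) axes are replaced by a single axis of size d_out
    return (*batch, d_out)


print(fsw_output_shape((50, 15), d_in=15, d_out=123))  # (123,)
print(fsw_output_shape((5, 3, 7, 9, 50, 15), d_in=15, d_out=123,
                       w_shape=(5, 3, 7, 9, 50)))      # (5, 3, 7, 9, 123)
```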

Continuing with the same `d`, `m`, `device` and `dtype`, the example below illustrates the difference between the core embedding, which is invariant to the input multiset size, and an embedding that explicitly encodes it.
```python
# --- Encoding multiset size (total mass) ---
# By default, the embedding is invariant to the input multiset size, since it
# treats inputs as *probability measures*.
# Set `encode_total_mass = True` to make the embedding encode the size of the
# input multisets, or, more generally, the total mass (i.e., the sum of weights).
embed_total_mass_invariant = FSWEmbedding(d_in=d, d_out=m, device=device, dtype=dtype)
embed_total_mass_aware =     FSWEmbedding(d_in=d, d_out=m, encode_total_mass=True, device=device, dtype=dtype)

# Two multisets of different sizes but identical element proportions
X = torch.rand(3, d, device=device, dtype=dtype)
v1, v2, v3 = X[0], X[1], X[2]

X1 = torch.stack([v1, v2, v3])
X2 = torch.stack([v1, v1, v2, v2, v3, v3])

# Embedding *without* total mass encoding
X1_emb = embed_total_mass_invariant(X1)
X2_emb = embed_total_mass_invariant(X2)

# Embedding *with* total mass encoding
X1_emb_aware = embed_total_mass_aware(X1)
X2_emb_aware = embed_total_mass_aware(X2)

# Measure the differences
diff_invariant = torch.norm(X1_emb - X2_emb).item()
diff_aware = torch.norm(X1_emb_aware - X2_emb_aware).item()

print("Two different-size multisets with identical element proportions:")
print("X₁ = {v1, v2, v3},   X₂ = {v1, v1, v2, v2, v3, v3}")
print("Embedding difference: ‖Embed(X₁) − Embed(X₂)‖₂")
print(f"With total mass encoding:     {diff_aware}")
print(f"Without total mass encoding:  {diff_invariant:.2e}")
```

Output:
```
Two different-size multisets with identical element proportions:
X₁ = {v1, v2, v3},   X₂ = {v1, v1, v2, v2, v3, v3}
Embedding difference: ‖Embed(X₁) − Embed(X₂)‖₂
With total mass encoding:     3.0
Without total mass encoding:  5.09e-07
```
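
One way to see why the default embedding cannot tell X₁ from X₂: a weighted multiset can be split into its total mass and the probability measure obtained by normalizing its weights. X₁ and X₂ share the same probability measure and differ only in total mass (3 vs. 6), so an embedding of probability measures treats them as identical, while `encode_total_mass=True` makes the embedding also encode the quantity that differs. The sketch below illustrates this decomposition in plain Python, using string labels in place of vectors in ℝ^d and a hypothetical helper `total_mass_and_measure`. It does not call fswlib and says nothing about how fswlib encodes the total mass internally.

```python
from collections import Counter
from fractions import Fraction


def total_mass_and_measure(points, weights=None):
    """Split a weighted multiset into (total mass, normalized probability measure).

    `points` are hashable labels standing in for vectors in R^d; weights default to 1.
    """
    if weights is None:
        weights = [1] * len(points)
    mass = Counter()
    for p, w in zip(points, weights):
        mass[p] += Fraction(w)  # repeated elements are merged by summing their weights
    total = sum(mass.values())
    return total, {p: w / total for p, w in mass.items()}


X1 = ["v1", "v2", "v3"]
X2 = ["v1", "v1", "v2", "v2", "v3", "v3"]

mass1, mu1 = total_mass_and_measure(X1)
mass2, mu2 = total_mass_and_measure(X2)
print(mu1 == mu2)    # True: same probability measure
print(mass1, mass2)  # 3 6: different total mass
```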

---

## 📄 Citation

If you use this library in your research, please cite our paper:

```bibtex
@inproceedings{amir2025fsw,
  title={Fourier Sliced-{W}asserstein Embedding for Multisets and Measures},
  author={Tal Amir and Nadav Dym},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}
```

---

## 🔗 Links

- **Paper**: [ICLR 2025](https://iclr.cc/virtual/2025/poster/30562)  
- **Code**: [GitHub repository](https://github.com/tal-amir/fswlib)

---

## 👨🏻‍🔧 Maintainer

This library is maintained by **Tal Amir**  
Homepage: [https://tal-amir.github.io](https://tal-amir.github.io)  
Email: [talamir@technion.ac.il](mailto:talamir@technion.ac.il)

