Metadata-Version: 2.4
Name: musical-mel-transform
Version: 0.1.0
Summary: A PyTorch-based musical mel-frequency transform for audio processing
Project-URL: Homepage, https://github.com/worldveil/musical_mel_transform_torch
Project-URL: Repository, https://github.com/worldveil/musical_mel_transform_torch.git
Project-URL: Issues, https://github.com/worldveil/musical_mel_transform_torch/issues
Author-email: Your Name <your.email@example.com>
License: MIT License
        
        Copyright (c) 2024 Will Drevo
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Keywords: audio,mel,music,pytorch,signal-processing,transform
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
Requires-Dist: black>=23.0.0
Requires-Dist: flake8>=5.0.0
Requires-Dist: ipython>=8.12.3
Requires-Dist: isort>=5.0.0
Requires-Dist: librosa>=0.9.0
Requires-Dist: matplotlib>=3.7.5
Requires-Dist: mypy>=1.0.0
Requires-Dist: numpy>=1.21.0
Requires-Dist: onnx>=1.16.0
Requires-Dist: onnxruntime>=1.17.0
Requires-Dist: onnxscript>=0.1.0.dev20240617
Requires-Dist: pre-commit>=3.0.0
Requires-Dist: pytest-cov>=4.0.0
Requires-Dist: pytest>=7.0.0
Requires-Dist: scipy>=1.7.0
Requires-Dist: seaborn>=0.11.0
Requires-Dist: sphinx-rtd-theme>=1.0.0
Requires-Dist: sphinx>=5.0.0
Requires-Dist: torch>=2.1.0
Requires-Dist: torchaudio>=2.1.0
Requires-Dist: tqdm>=4.64.0
Description-Content-Type: text/markdown

# Musical Mel Transform

[![CI](https://github.com/worldveil/musical_mel_transform_torch/actions/workflows/ci.yml/badge.svg)](https://github.com/worldveil/musical_mel_transform_torch/actions)
[![PyPI version](https://badge.fury.io/py/musical-mel-transform.svg)](https://badge.fury.io/py/musical-mel-transform)
[![Python versions](https://img.shields.io/pypi/pyversions/musical-mel-transform.svg)](https://pypi.org/project/musical-mel-transform/)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

A PyTorch-based musical mel-frequency transform for audio processing, optimized for performance and ONNX compatibility.

If you've ever wanted torch features from an audio signal that directly represent semitones (or quarter tones!), this is the package for you.

![Musical Mel Transform Filterbank](img/low_freq_filters.png)

Here, I show a few mel filters (the notes in the legend). Each filter draws weight from neighboring FFT bins, and where a note falls between FFT bins determines that weighting, so a neural network can tell the mel bins of adjacent musical notes apart.

## What is this and why use it?

Mel scales in torchaudio, librosa, and other packages also use logarithmically spaced bins on the frequency spectrum, but they:

* Don't map to Western musical scale notes
* Still have poor resolution in the low end (especially when the FFT size is < 2048)
* Tend to spend much of the feature count on high frequencies, where the linearly spaced FFT bins start to have a many-to-one relationship with musical tones
* Often are not ONNX-compatible because the FFT uses complex numbers

This package aims to alleviate these issues!
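
To make the low-end resolution problem concrete, here is a minimal sketch that compares the fixed FFT bin width with the width of one semitone at a given frequency. It uses only the standard library, and the helper names are illustrative assumptions, not part of this package.

```python
import math


def fft_bin_width_hz(sample_rate: int, frame_size: int) -> float:
    """Spacing between adjacent FFT bins in Hz (constant across the spectrum)."""
    return sample_rate / frame_size


def interval_width_hz(freq_hz: float, interval: float = 1.0) -> float:
    """Distance in Hz from freq_hz to the pitch `interval` semitones above it."""
    return freq_hz * (2.0 ** (interval / 12.0) - 1.0)


def crossover_hz(sample_rate: int, frame_size: int, interval: float = 1.0) -> float:
    """Frequency at which one musical interval is exactly one FFT bin wide."""
    return fft_bin_width_hz(sample_rate, frame_size) / (2.0 ** (interval / 12.0) - 1.0)


sr, n = 44100, 2048
print(f"FFT bin width: {fft_bin_width_hz(sr, n):.2f} Hz")
for f in (82.41, 440.0, 4000.0):
    print(f"{f:8.2f} Hz: one semitone spans {interval_width_hz(f):.2f} Hz")
print(f"crossover: {crossover_hz(sr, n):.1f} Hz")
```

With these settings the crossover lands at roughly 360 Hz. Below it, semitones are narrower than one FFT bin, so neighboring notes share bins. Well above it, each semitone spans many FFT bins.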

## How does it work?

Mel scale is just a mapping of FFT bins -> new bins. So each mel bin is just a weighted sum of the usual linearly-spaced FFT bins. That's it!

This code does some adaptive widening to pick good weighted combinations of FFT bins so that pitches are discernible to a downstream layer in your network. You can coarsen or refine the tone granularity, so switching between semitones and quarter tones is just a parameter change.
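
The "weighted sum of FFT bins" idea can be sketched in a few lines. The example below is a deliberately simplified stand-in: it assumes plain two-bin linear interpolation around each note's center frequency and does not reproduce the package's adaptive widening. All function names are hypothetical.

```python
import math


def semitone_centers(f_min: float, f_max: float, interval: float = 1.0) -> list:
    """Center frequencies spaced `interval` semitones apart, from f_min up to f_max."""
    centers = []
    k = 0
    while True:
        f = f_min * 2.0 ** (k * interval / 12.0)
        if f > f_max:
            break
        centers.append(f)
        k += 1
    return centers


def interp_filterbank(centers, sample_rate: int, frame_size: int) -> list:
    """One row per mel bin; each row weights the two FFT bins around the center.

    Assumes every center is at or below the Nyquist frequency.
    """
    n_fft_bins = frame_size // 2 + 1
    bin_hz = sample_rate / frame_size
    bank = []
    for f in centers:
        pos = f / bin_hz            # fractional FFT bin index of this note
        lo = int(math.floor(pos))
        frac = pos - lo             # how far the note sits toward the upper bin
        row = [0.0] * n_fft_bins
        row[lo] += 1.0 - frac
        if lo + 1 < n_fft_bins:
            row[lo + 1] += frac
        bank.append(row)
    return bank


def apply_filterbank(bank, fft_magnitudes) -> list:
    """Each mel bin is a weighted sum of the FFT magnitude bins."""
    return [sum(w * m for w, m in zip(row, fft_magnitudes)) for row in bank]


centers = semitone_centers(80.0, 8000.0, interval=1.0)
bank = interp_filterbank(centers, sample_rate=44100, frame_size=2048)
mel = apply_filterbank(bank, [1.0] * (2048 // 2 + 1))
print(len(centers), len(bank[0]), round(mel[0], 6))
```

Because the weights depend on where a note falls between two FFT bins, two notes that share the same pair of bins still produce different rows, which is the property the filterbank plot illustrates.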

It also uses a nice [convolutional FFT](src/musical_mel_transform/conv_fft.py) to apply the FFT in a way that is ONNX export compatible, if that is important to you. If not, just stick with the torch native FFT - it is faster.
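
For intuition on how an FFT magnitude can be computed without complex numbers, here is a minimal sketch of a real-valued DFT built from cosine and sine basis rows. It is not the code in `conv_fft.py`; it only illustrates the general idea, under the assumption that such fixed basis matrices can be applied as ordinary real-valued weights. The function names are made up for this example.

```python
import math


def real_dft_bases(frame_size: int):
    """Cosine and sine basis rows for bins 0..frame_size // 2 (the rfft bins)."""
    n_bins = frame_size // 2 + 1
    cos_rows, sin_rows = [], []
    for k in range(n_bins):
        angles = [2.0 * math.pi * k * n / frame_size for n in range(frame_size)]
        cos_rows.append([math.cos(a) for a in angles])   # real part basis
        sin_rows.append([-math.sin(a) for a in angles])  # imaginary part basis
    return cos_rows, sin_rows


def dft_magnitude(frame, cos_rows, sin_rows) -> list:
    """Magnitude spectrum from two real matrix-vector products."""
    mags = []
    for c_row, s_row in zip(cos_rows, sin_rows):
        re = sum(c * x for c, x in zip(c_row, frame))
        im = sum(s * x for s, x in zip(s_row, frame))
        mags.append(math.sqrt(re * re + im * im))
    return mags


frame_size = 16
cos_rows, sin_rows = real_dft_bases(frame_size)
# A cosine that completes exactly 3 cycles in the frame lands in bin 3
frame = [math.cos(2.0 * math.pi * 3 * n / frame_size) for n in range(frame_size)]
mags = dft_magnitude(frame, cos_rows, sin_rows)
print(max(range(len(mags)), key=lambda k: mags[k]))
```

The cost is a dense `frame_size`-wide product per bin instead of the FFT's divide-and-conquer, which is consistent with the native torch FFT being faster.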

## Usage example

### MusicalMelTransform

```python
import torch
from musical_mel_transform import MusicalMelTransform

# Construct the transform once and reuse it for every batch
mel_transform = MusicalMelTransform(
    sample_rate=44100,            # Audio sample rate
    frame_size=2048,              # FFT size
    interval=1.0,                 # Musical interval in semitones
    f_min=80,                     # Minimum frequency (Hz)
    f_max=8_000,                  # Maximum frequency (Hz)
    passthrough_cutoff_hz=10_000, # High-freq passthrough threshold
    norm=True,                    # Normalize filterbank
    min_bins=2,                   # Minimum filter width
    adaptive=True,                # Adaptive filter sizing
    passthrough_grouping_size=3,  # High-freq bin grouping
    use_conv_fft=True,            # Use convolution-based FFT
    learnable_weights="mel"       # What kind, if any, of learnable weights to use
)

# A batch of audio frames, shape [batch_size, frame_size]
batch_size, frame_size = 4, 2048
audio_frames = torch.randn(batch_size, frame_size)

# Returns the mel features and the FFT magnitude features
mel_feats, fft_mag_feats = mel_transform(audio_frames)
```

## Quickstart in a toy torch network

```python
import torch
import torch.nn as nn
from musical_mel_transform import MusicalMelTransform

class SimpleAudioClassifier(nn.Module):
    """Simple audio classifier using musical mel features"""
    def __init__(self, frame_size: int, n_classes: int = 10):
        super().__init__()

        self.mel_transform = MusicalMelTransform(
            # sample rate in Hz
            sample_rate=44100,

            # number of samples to include / the size of the FFT
            # by default a hann window is applied as well
            frame_size=frame_size,

            # this is the interval in semitones with which to place a mel bin
            interval=0.5,  # quarter tone resolution

            # above this frequency threshold, don't try to create mel bins,
            # simply let the FFT bins "pass through". this is useful because as
            # frequency increases, mel bins become spaced at *larger* intervals
            # than FFT bins -- the opposite of low frequencies!
            passthrough_cutoff_hz=8_000,

            # dimensionality reduction technique. there can be a lot of passthrough FFT bins,
            # so this simply groups them together in order. a `passthrough_grouping_size=2`
            # would reduce your passthrough FFT bins by half. =1 would simply pass them
            # all through, unchanged.
            passthrough_grouping_size=3,

            # ignore FFT bins altogether above this frequency
            f_max=10_000,

            # this argument ensures that each mel bin is normalized to have unit weight
            norm=True,

            # the minimum number of FFT bins to widen to in order to fill a mel bin
            # definitely want this >1, otherwise you'll just be copying nearest FFT bin!
            min_bins=2,

            # highly recommend this to be True. This does some adaptive widening to ensure that
            # both high and low frequency bins have around the same "spread" to neighboring FFT bins
            adaptive=True,

            # this will slow down your network! but it will allow you to
            # export ONNX with dynamo=True and still have FFT. whenever onnx folks
            # finally add complex number support this won't be needed anymore,
            # but until then....
            use_conv_fft=True,

            # what kind of learnable weights to use, if any. Options are: None, "fft", "mel"
            # if "fft", a learnable parameter elementwise-multiplies the raw FFT bins
            # before the mel transform. if "mel", the elementwise multiply happens after.
            # If None, no reweighting is done!
            learnable_weights="mel",
        )

        # Simple classifier head
        self.classifier = nn.Sequential(
            nn.Linear(self.mel_transform.n_mel, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, n_classes)
        )

    def forward(self, audio_frame):
        """
        Args:
            audio_frame: [batch_size, frame_size] - frames of audio
        Returns:
            logits: [batch_size, n_classes] - Classification scores
        """
        mel_features, _ = self.mel_transform(audio_frame)
        return self.classifier(mel_features)

# Example usage
model = SimpleAudioClassifier(n_classes=10, frame_size=2048)

# Process a batch of audio frames
batch_size = 4
frame_size = 2048
audio_frames = torch.randn(batch_size, frame_size)

with torch.no_grad():
    predictions = model(audio_frames)
    print(f"Predictions shape: {predictions.shape}")  # [4, 10]

# Export to ONNX
torch.onnx.export(
    model,
    audio_frames,
    "audio_classifier.onnx",
    export_params=True,
    opset_version=18,
    input_names=['audio_frame'],
    output_names=['predictions'],
    dynamic_axes=None,  # I recommend this for best performance
)
print("Model exported to audio_classifier.onnx")
```
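
The `passthrough_grouping_size` comment in the quickstart describes grouping consecutive passthrough FFT bins in order. The sketch below shows one way that reduction could look. Two details are assumptions, not confirmed behavior of the package: each group is combined by averaging, and leftover bins form a smaller final group. The function name is hypothetical.

```python
def group_passthrough_bins(bins, group_size: int) -> list:
    """Reduce consecutive bins to one value per group (assumption: the mean)."""
    if group_size < 1:
        raise ValueError("group_size must be >= 1")
    grouped = []
    for start in range(0, len(bins), group_size):
        chunk = bins[start:start + group_size]  # last chunk may be shorter
        grouped.append(sum(chunk) / len(chunk))
    return grouped


passthrough = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(group_passthrough_bins(passthrough, 2))  # half as many values
print(group_passthrough_bins(passthrough, 1))  # unchanged
```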

## TODO

* [ ] Easy way to demo plotting a piece of audio or mp3 using a set of MusicalMelTransform() init params
* [ ] Speed ups for `ConvFFT` module

## Installation

```bash
pip install musical-mel-transform

# or editable install
cd musical_mel_transform/
pip install -e .
```

### Different Musical Scales

`interval` is in units of semitones.

```python
chromatic_transform = MusicalMelTransform(interval=1.0)
quarter_tone_transform = MusicalMelTransform(interval=0.5)
```
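
As a rough guide to how `interval` affects feature count, the sketch below counts equal-tempered centers between two frequencies. It assumes centers at `f_min * 2**(k * interval / 12)`. The package's actual number of output features may differ, since it can also include passthrough bins. Both helper names are illustrative.

```python
import math


def step_ratio(interval: float) -> float:
    """Frequency ratio between adjacent bins for an interval given in semitones."""
    return 2.0 ** (interval / 12.0)


def bins_in_range(f_min: float, f_max: float, interval: float) -> int:
    """Number of centers f_min * 2**(k * interval / 12) that are <= f_max."""
    steps = (12.0 / interval) * math.log2(f_max / f_min)
    return int(math.floor(steps)) + 1


for name, interval in (("semitone", 1.0), ("quarter tone", 0.5)):
    n = bins_in_range(80.0, 8000.0, interval)
    print(f"{name}: ratio {step_ratio(interval):.4f}, {n} bins from 80 Hz to 8 kHz")
```

Halving the interval doubles the number of centers in the same range (80 versus 160 here), so quarter-tone resolution roughly doubles the mel feature dimension.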

### Visualization

```python
from musical_mel_transform import MusicalMelTransform, plot_low_filters

# Visualize the mel filterbank
transform = MusicalMelTransform()

plot_low_filters(
    transform,
    bank_idx_to_show=[0, 5, 10, 15, 20, 25],
    x_max_hz=1000,
    legend=True
)
```

## Demo Script

Run the interactive demo to explore different features:

```bash
# Run all demos
musical-mel-demo

# Run specific demos
musical-mel-demo --demo basic           # Basic usage
musical-mel-demo --demo params          # Parameter comparison
musical-mel-demo --demo filters         # Filterbank visualization
musical-mel-demo --demo performance     # Performance benchmarks
musical-mel-demo --demo onnx            # ONNX export
musical-mel-demo --demo musical         # Musical analysis

# Skip plot generation
musical-mel-demo --no-plots
```

## Running Tests

```bash
# Run all tests
pytest

# Run with coverage
pytest --cov=musical_mel_transform --cov-report=html

# Run specific test categories
pytest tests/test_fft.py -v                    # FFT tests
pytest tests/test_onnx_export.py -v           # ONNX tests
pytest -m "not slow" -v                       # Skip slow tests
pytest -m integration -v                      # Integration tests only
```

### Code Quality

```bash
# Format code
black src/ tests/
isort src/ tests/
```

### Performance Testing

```bash
# Run performance benchmarks
pytest tests/test_fft.py::test_exact_fft_transform_matches_torch_rfft -v
pytest tests/test_fft.py::test_exact_mel_transform_matches_torch_rfft -v

# Extended performance tests
pytest -m slow -v
```

## Deployment

### Publish a new version to PyPI (step-by-step)

```bash
pip install --upgrade build twine
```

```bash
# 1) Bump version
# Edit src/musical_mel_transform/__init__.py (e.g., __version__ = "0.1.1")
```

```bash
# 2) Clean previous builds
rm -rf dist/ build/ *.egg-info
```

```bash
# 3) Build wheel and sdist
python -m build
```

```bash
# 4) Verify the artifacts
python -m twine check dist/*
```

```bash
# 5) (Optional) Upload to Test PyPI and verify install
python -m twine upload --repository testpypi dist/*
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps musical-mel-transform
```

```bash
# 6) Upload to PyPI
python -m twine upload dist/*
```

### GitHub Release Process

1. **Update Version**: Bump version in `src/musical_mel_transform/__init__.py`
2. **Create Tag**:
   ```bash
   git tag v0.1.0
   git push origin v0.1.0
   ```
3. **Create Release**: Go to GitHub releases and create a new release from the tag
4. **Automated Deployment**: The CI will automatically build and deploy to PyPI

### Environment Variables for CI/CD

Set these secrets in your GitHub repository:

- `PYPI_API_TOKEN`: Your PyPI API token for automated publishing
- `CODECOV_TOKEN`: (Optional) For code coverage reporting

## API Reference

### MusicalMelTransform

```python
MusicalMelTransform(
    sample_rate: int = 44100,           # Audio sample rate
    frame_size: int = 2048,             # FFT size
    interval: float = 1.0,              # Musical interval in semitones
    f_min: float = 80.0,                # Minimum frequency (Hz)
    f_max: Optional[float] = None,      # Maximum frequency (Hz)
    passthrough_cutoff_hz: float = 10000, # High-freq passthrough threshold
    norm: bool = True,                  # Normalize filterbank
    min_bins: int = 2,                  # Minimum filter width
    adaptive: bool = True,              # Adaptive filter sizing
    passthrough_grouping_size: int = 3, # High-freq bin grouping
    use_conv_fft: bool = False,         # Use convolution-based FFT
    learnable_weights: Optional[str] = None  # What kind, if any, of learnable weights to use
)
```

### ConvFFT

```python
ConvFFT(
    frame_size: int,            # Size of the FFT
    window_type: Optional[str]  # One of: None, "hann", "hamming"
)
```

ONNX-compatible FFT implementation using matrix multiplication.

## Performance

Performance comparison between the two FFT implementations on my M1 MacBook, running on CPU:

| Configuration | Time (ms) |
|---------------|-----------|
| Torch FFT (frame_size=1024) | 0.01 +/- 0.01 |
| Conv FFT (frame_size=1024) | 0.08 +/- 0.05 |
| Torch FFT (frame_size=2048) | 0.02 +/- 0.01 |
| Conv FFT (frame_size=2048) | 0.39 +/- 0.05 |

*Note: Conv FFT trades some speed for ONNX compatibility and consistent cross-platform behavior.*
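
If you want to collect mean +/- standard deviation timings like these on your own machine, a small harness along the following lines is one option. This is a generic sketch, not the benchmark code that produced the table. The warmup and run counts are arbitrary assumptions, and the timed lambda is a placeholder workload.

```python
import statistics
import time


def time_call_ms(fn, n_warmup: int = 10, n_runs: int = 100):
    """Return (mean, standard deviation) of fn's wall-clock time in milliseconds."""
    for _ in range(n_warmup):
        fn()  # warm caches before measuring
    samples = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    return statistics.mean(samples), statistics.stdev(samples)


# Placeholder workload; swap in a call to your transform on a fixed input
mean_ms, std_ms = time_call_ms(lambda: sum(range(1000)))
print(f"{mean_ms:.4f} +/- {std_ms:.4f} ms")
```

To time the transform itself, pass something like `lambda: transform(frames)`, where `transform` is a `MusicalMelTransform` instance and `frames` is a fixed input tensor, ideally inside `torch.no_grad()`.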

### Development Workflow

1. Fork the repository
2. Create a feature branch: `git checkout -b feature-name`
3. Make your changes and add tests
4. Run the test suite: `pytest`
5. Check code quality: `black src/ tests/` and `isort src/ tests/`
6. Submit a pull request

## Citation

If you use this library in your research, please cite:

```bibtex
@software{musical_mel_transform,
  title={Musical Mel Transform: PyTorch-based musical mel-frequency transform},
  author={Will Drevo},
  url={https://github.com/worldveil/musical_mel_transform_torch},
  year={2025}
}
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
