Metadata-Version: 2.2
Name: fsgdm
Version: 1.0
Summary: The official implementation of the Frequency SGD with Momentum (FSGDM) optimizer in PyTorch.
Home-page: https://github.com/yinleung/FSGDM
Author: Xianliang Li
Author-email: yinleung.ley@gmail.com
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.7.0
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# FSGDM

**FSGDM** (Frequency Stochastic Gradient Descent with Momentum) is a PyTorch optimizer that adjusts the filtering characteristics of momentum over the course of training, following a dynamic magnitude response that has proven empirically effective.

Paper: [On the Performance Analysis of Momentum Method: A Frequency Domain Perspective](https://openreview.net/forum?id=tznvtmSEiN)

Authors: Xianliang Li, Jun Luo, Zhiwei Zheng, Hanxiao Wang, Li Luo, Lingkun Wen, Linlong Wu, Sheng Xu

This repository contains the official PyTorch implementation of FSGDM.

---

## Usage

Install PyTorch (`torch>=1.7.0`) and run

```
pip install fsgdm
```

or simply copy the `fsgdm.py` file to your codebase.

Then construct the FSGDM optimizer as follows:

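```python
from fsgdm import FSGDM

# model, lr, weight_decay, c_scaling, v_coefficient, n_stages and sigma
# are placeholders for your own model and hyperparameter values.
optimizer = FSGDM(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
    c_scaling=c_scaling,
    v_coefficient=v_coefficient,
    n_stages=n_stages,
    sigma=sigma,
)
```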

Replace `lr`, `weight_decay`, `c_scaling`, `v_coefficient`, `n_stages`, and `sigma` with values suitable for your task.

**Hyperparameter choices:**

- `lr` & `weight_decay`: We recommend the same values you would typically use for SGDM in PyTorch.
- `c_scaling` & `v_coefficient`: These parameters should lie in the **optimal zone**, the region of (`c_scaling`, `v_coefficient`) pairs that performs well. For CNNs, a good rule of thumb is to aim for the region where `30.992/v_coefficient ≈ 1 + 1/c_scaling`.
- `n_stages`: The number of training stages. Try several values to find the best setting for your task.
- `sigma`: The total number of gradient update steps over training (for example, epochs × batches per epoch). You need to compute this value yourself.

**Remark: The optimal hyperparameter zones can vary across different tasks. We welcome contributions that explore these optimal zones for various learning tasks.**
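The two values that require hand calculation can be scripted. The sketch below assumes that `sigma` means the total number of optimizer steps over the whole run (one step per mini-batch) and counts batches the way a map-style PyTorch `DataLoader` does (the last partial batch is kept unless `drop_last=True`). The helper names `total_update_steps` and `v_from_rule_of_thumb` are hypothetical and are not part of the `fsgdm` package.

```python
import math


def total_update_steps(num_samples: int, batch_size: int, epochs: int,
                       drop_last: bool = False) -> int:
    """Hypothetical helper: total optimizer steps, assuming one step per mini-batch."""
    if drop_last:
        # The incomplete final batch is discarded.
        steps_per_epoch = num_samples // batch_size
    else:
        # The incomplete final batch still triggers an update.
        steps_per_epoch = math.ceil(num_samples / batch_size)
    return epochs * steps_per_epoch


def v_from_rule_of_thumb(c_scaling: float) -> float:
    """Hypothetical helper: solve 30.992 / v = 1 + 1 / c for v (CNN rule of thumb)."""
    if c_scaling <= 0:
        raise ValueError("c_scaling must be positive")
    return 30.992 / (1.0 + 1.0 / c_scaling)


# Hypothetical CIFAR-100-sized setup: 50,000 training images, batch size 128, 300 epochs.
sigma = total_update_steps(num_samples=50_000, batch_size=128, epochs=300)
```

For this hypothetical setup each epoch has 391 batches, so `sigma` is 117,300 and would be passed as `FSGDM(..., sigma=sigma)`. Treat the rule-of-thumb `v_coefficient` as a starting point for a search, not a tuned value.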

### Examples

Examples of using the `fsgdm` package can be found in the `examples` folder. These include:

- [Image classification (CIFAR-100) using ResNet50](./examples/CIFAR100/)*
- More examples to be added

*This example is adapted from code generated by GPT-4o.

## Frequency Domain Analysis Framework in a Nutshell

The momentum method can be interpreted as a **time-invariant filter for gradients**, where adjustments to momentum coefficients modify the filter characteristics.
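As a sketch of this view, assume a momentum update of the form $m_t = u\,m_{t-1} + v\,g_t$ with constant coefficients $u$ and $v > 0$ (PyTorch's SGD momentum buffer takes this form after its first step, with $u$ = `momentum` and $v$ = `1 - dampening`). Taking the z-transform gives a first-order filter from gradients to momentum:

```latex
M(z) = u\,z^{-1} M(z) + v\,G(z)
\quad\Longrightarrow\quad
H(z) = \frac{M(z)}{G(z)} = \frac{v}{1 - u\,z^{-1}},
\qquad
\bigl|H(e^{j\omega})\bigr| = \frac{v}{\sqrt{1 - 2u\cos\omega + u^{2}}}
```

For $0 < u < 1$ the gain falls from $v/(1-u)$ at $\omega = 0$ to $v/(1+u)$ at $\omega = \pi$, so the filter is low-pass; with $u = 0$ and $v = 1$ it passes the gradient unchanged.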

High-frequency gradient components correspond to large, abrupt changes in the gradient, while low-frequency components correspond to smooth, gradual changes.

Significant findings for DNN training:

- High-frequency gradient components are undesired in the late stages of training
- Preserving the original gradient in the early stages improves performance
- Gradually amplifying low-frequency gradient components during training enhances performance
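The findings above can be related to the shape of a simple momentum filter. The sketch below assumes the constant-coefficient update `m_t = u * m_{t-1} + v * g_t` and evaluates its magnitude response at the lowest (`omega = 0`) and highest (`omega = pi`) frequencies. The `(u, v)` settings are hypothetical illustrations, not the schedule FSGDM actually uses.

```python
import math


def magnitude_response(u: float, v: float, omega: float) -> float:
    """|H(e^{j*omega})| for the assumed filter m_t = u * m_{t-1} + v * g_t."""
    return v / math.sqrt(1.0 - 2.0 * u * math.cos(omega) + u * u)


# Hypothetical (u, v) settings, chosen only to illustrate the three findings.
settings = {
    "plain gradient (u=0, v=1)": (0.0, 1.0),
    "mild smoothing (u=0.5, v=1)": (0.5, 1.0),
    "strong smoothing (u=0.9, v=1)": (0.9, 1.0),
}

for name, (u, v) in settings.items():
    low = magnitude_response(u, v, 0.0)       # gain on the slowest-varying component
    high = magnitude_response(u, v, math.pi)  # gain on the fastest-varying component
    print(f"{name}: low-frequency gain {low:.3f}, high-frequency gain {high:.3f}")
```

With `u = 0, v = 1` both gains are 1, so the original gradient is preserved. Raising `u` with `v` fixed increases the low-frequency gain `v / (1 - u)` (10 at `u = 0.9`) while lowering the high-frequency gain `v / (1 + u)` (about 0.53 at `u = 0.9`), which matches amplifying low-frequency components and suppressing high-frequency ones.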

*For a more detailed explanation, please refer to our paper.*

## Paper

If you find FSGDM useful in your research, please cite our paper in the following format.

```
@inproceedings{
li2025on,
title={On the Performance Analysis of Momentum Method: A Frequency Domain Perspective},
author={Xianliang Li and Jun Luo and Zhiwei Zheng and Hanxiao Wang and Li Luo and Lingkun Wen and Linlong Wu and Sheng Xu},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=tznvtmSEiN}
}
```

## License

See the [License file](/LICENSE).
