Metadata-Version: 2.1
Name: simple-cats
Version: 0.1
Summary: Python package for CATS paper
Home-page: https://github.com/ScalingIntelligence/CATS
Author: Jeyong Lee, Donghyun Lee, ...
Author-email: je-yong.lee@worc.ox.ac.uk
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown
Requires-Dist: dataclasses
Requires-Dist: filelock~=3.9.0
Requires-Dist: numpy~=1.23.5
Requires-Dist: tqdm~=4.64.1
Requires-Dist: packaging~=22.0
Requires-Dist: requests~=2.28.1
Requires-Dist: importlib_metadata~=4.11.3
Requires-Dist: regex~=2022.7.9
Requires-Dist: pandas~=1.2.5
Requires-Dist: pyyaml
Requires-Dist: python-dateutil~=2.8.2
Requires-Dist: setuptools~=65.6.3
Requires-Dist: rouge_score~=0.1.2
Requires-Dist: absl-py~=1.4.0
Requires-Dist: nltk~=3.8.1
Requires-Dist: sacremoses
Requires-Dist: attrs==23.1.0
Requires-Dist: protobuf<4.0,>=3.6
Requires-Dist: portpicker>=1.3.1
Requires-Dist: grpcio<=1.48.2,>=1.35.0
Requires-Dist: grpcio-tools<=1.48.2,>=1.35.0
Requires-Dist: googleapis-common-protos>=1.56.4
Requires-Dist: sqlalchemy<=1.4.20,>=1.4
Requires-Dist: evaluate
Requires-Dist: scikit-learn
Requires-Dist: optimum
Requires-Dist: onnx
Requires-Dist: onnxruntime
Requires-Dist: jax>=0.4.10
Requires-Dist: jaxlib>=0.4.10
Requires-Dist: jaxopt>=0.7
Requires-Dist: flax>=0.6.10
Requires-Dist: optax>=0.1.5
Requires-Dist: chex>=0.1.7
Requires-Dist: py7zr
Requires-Dist: accelerate
Requires-Dist: wandb
Requires-Dist: deepspeed
Requires-Dist: equinox
Requires-Dist: transformers==4.36.2
Requires-Dist: triton==2.1.0
Requires-Dist: huggingface_hub
Requires-Dist: tokenizers
Requires-Dist: datasets
Requires-Dist: testresources
Requires-Dist: mpi4py-mpich
Requires-Dist: seaborn
Requires-Dist: torch
Requires-Dist: peft
Requires-Dist: lm_eval
Requires-Dist: trl

This repository contains the official implementation of "CATS: Contextually-Aware Thresholding for Sparsity in Large Language Models" by Je-Yong Lee, Donghyun Lee, Genghan Zhang, Mo Tiwari, and Azalia Mirhoseini, as described in our paper on [arXiv](https://arxiv.org/abs/2404.08763).

## Overview
Our paper, "CATS: Contextually-Aware Thresholding for Sparsity in Large Language Models," introduces CATS, a method for reducing the computational cost of deploying LLMs without sacrificing their performance on downstream tasks. At its core is a novel activation function that increases activation sparsity, i.e., the fraction of activations that are exactly zero and whose associated computation can therefore be skipped.
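As a rough illustration of the idea, the snippet below applies a magnitude threshold to the output of the gate activation in a gated (Llama/Mistral-style) MLP, zeroing every entry whose absolute value falls below a cutoff `threshold`. It is a simplified NumPy sketch, not the implementation in this repository: the weight names `w_gate`, `w_up`, `w_down` and the fixed threshold value are assumptions chosen for illustration.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU / Swish activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def cats(x: np.ndarray, threshold: float) -> np.ndarray:
    """Zero out entries whose magnitude is below `threshold`; keep the rest unchanged."""
    return np.where(np.abs(x) >= threshold, x, 0.0)


def cats_mlp(x, w_gate, w_up, w_down, threshold):
    """Gated MLP whose gate activations are sparsified by `cats` (illustrative only).

    Shapes: x is (d_model,), w_gate and w_up are (d_model, d_ff),
    w_down is (d_ff, d_model).
    """
    gate = cats(silu(x @ w_gate), threshold)  # many entries become exactly zero
    return (gate * (x @ w_up)) @ w_down


rng = np.random.default_rng(0)
d_model, d_ff = 8, 32
x = rng.standard_normal(d_model)
w_gate = rng.standard_normal((d_model, d_ff))
w_up = rng.standard_normal((d_model, d_ff))
w_down = rng.standard_normal((d_ff, d_model))

out = cats_mlp(x, w_gate, w_up, w_down, threshold=0.5)
gate = cats(silu(x @ w_gate), 0.5)
print(out.shape, f"fraction of zeroed gate entries: {np.mean(gate == 0):.2f}")
```

With `threshold=0.0` nothing is zeroed and the function reduces to the ordinary dense gated MLP; raising the threshold trades a small amount of accuracy for more zeros.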

CATS can be applied to base models such as Mistral-7B and Llama2-7B with only a small drop in downstream performance (within 1-2% of the base models), even at 50% activation sparsity. CATS-based models also converge faster when fine-tuned, and a custom GPU kernel turns the resulting sparsity into real wall-clock gains, improving inference speed by approximately 15%.
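To make the 50% figure concrete, one way to reach a target sparsity level is to set the threshold to a quantile of the absolute activations observed on calibration data. Sparsity can then save work because the rows of the down projection paired with zeroed activations never need to be read. The NumPy sketch below illustrates both points on random stand-in data; it is not the custom GPU kernel, and the helper names `calibrate_threshold` and `sparse_down_proj` are hypothetical.

```python
import numpy as np


def calibrate_threshold(activations: np.ndarray, target_sparsity: float) -> float:
    """Return a cutoff t such that about `target_sparsity` of |activations| fall below t."""
    return float(np.quantile(np.abs(activations), target_sparsity))


def sparse_down_proj(h: np.ndarray, w_down: np.ndarray) -> np.ndarray:
    """Compute h @ w_down using only the rows of w_down where h is nonzero.

    A GPU kernel can exploit the same structure to skip memory loads; here we
    only show that the result matches the dense product.
    """
    active = np.flatnonzero(h)  # indices of the surviving activations
    return h[active] @ w_down[active]


rng = np.random.default_rng(0)
d_ff, d_model = 1024, 64
calib = rng.standard_normal((256, d_ff))  # stand-in for activations gathered on calibration data
t = calibrate_threshold(calib, target_sparsity=0.5)

h = rng.standard_normal(d_ff)
h_sparse = np.where(np.abs(h) >= t, h, 0.0)
w_down = rng.standard_normal((d_ff, d_model))

print(f"threshold={t:.3f}, sparsity={np.mean(h_sparse == 0):.2f}")
assert np.allclose(sparse_down_proj(h_sparse, w_down), h_sparse @ w_down)
```

At 50% sparsity the sparse product reads roughly half of `w_down`, which is why memory-bound token generation can benefit; the actual speedup depends on the kernel and hardware.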

## Reproducing Results

To reproduce the experimental results and figures presented in our work, please follow the steps outlined below. The process has been simplified into a single script to ensure ease of use and to maintain consistency across different environments.

### Prerequisites

Ensure that the following prerequisites are in place:
- A Bash shell (Unix/Linux/macOS)
- The required Python packages, listed in `requirements.txt` (e.g., installed with `pip install -r requirements.txt`)
- An `accelerate` configuration file for your environment, created by running `accelerate config`
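Before starting a long run, it can help to check these prerequisites programmatically. The following is a hypothetical pre-flight check, not part of this repository. It reports whether `bash` is on the `PATH`, which of a few key packages from `requirements.txt` cannot be imported, and whether an `accelerate` config file exists at what we assume is its default location (`$HF_HOME/accelerate/default_config.yaml`, with `HF_HOME` defaulting to `~/.cache/huggingface`). Adjust the path if you saved your config elsewhere.

```python
import importlib.util
import os
import shutil
from pathlib import Path

# A subset of the packages this project depends on; extend as needed.
KEY_PACKAGES = ["torch", "transformers", "accelerate", "datasets", "lm_eval"]


def missing_packages(names):
    """Return the top-level module names that cannot be found on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def default_accelerate_config() -> Path:
    """Assumed default location of the file written by `accelerate config`."""
    cache = os.environ.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    hf_home = os.environ.get("HF_HOME", os.path.join(cache, "huggingface"))
    return Path(hf_home).expanduser() / "accelerate" / "default_config.yaml"


if __name__ == "__main__":
    problems = []
    if shutil.which("bash") is None:
        problems.append("bash not found on PATH")
    missing = missing_packages(KEY_PACKAGES)
    if missing:
        problems.append("missing Python packages: " + ", ".join(missing))
    config = default_accelerate_config()
    if not config.is_file():
        problems.append(f"no accelerate config at {config}; run `accelerate config`")
    print("\n".join(problems) if problems else "All prerequisites found.")
```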

### Steps

1. Open a terminal in the root directory of the project.
2. Run the following command:

```bash
bash reproduction_script.sh [path1] [path2]
```
- `[path1]`: directory where the checkpoints of the fine-tuned models will be stored.
- `[path2]`: directory where the experiment results, such as figures and histograms, will be saved.
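If you prefer to launch the run from Python (for example, from a job scheduler), a small wrapper can create both output directories and call the script with them as positional arguments. This is a hypothetical convenience sketch: it assumes it is executed from the project root where `reproduction_script.sh` lives, and creating the directories up front is our own choice (the script may also create them itself).

```python
import subprocess
from pathlib import Path


def build_command(checkpoint_dir: str, results_dir: str) -> list:
    """Argument list equivalent to `bash reproduction_script.sh [path1] [path2]`."""
    return ["bash", "reproduction_script.sh", str(checkpoint_dir), str(results_dir)]


def run_reproduction(checkpoint_dir: str, results_dir: str) -> None:
    """Create the output directories, then run the script; raises CalledProcessError on failure."""
    for d in (checkpoint_dir, results_dir):
        Path(d).mkdir(parents=True, exist_ok=True)
    subprocess.run(build_command(checkpoint_dir, results_dir), check=True)
```

For example, `run_reproduction("./checkpoints", "./results")` mirrors the shell command with placeholder paths; choose a checkpoint location with enough disk space for several 7B-parameter models.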

## Work in Progress
We are developing a framework that will make it easy to integrate CATS with any model from the Hugging Face `transformers` library.
