Metadata-Version: 2.1
Name: grouped_gemm
Version: 0.1.6
Summary: Grouped GEMM
Home-page: https://github.com/tgale96/grouped_gemm
Author: Trevor Gale
Author-email: tgale@stanford.edu
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: Unix
Description-Content-Type: text/markdown
License-File: LICENSE
Provides-Extra: dev
Requires-Dist: absl-py; extra == "dev"
Provides-Extra: all
Requires-Dist: absl-py; extra == "all"

# Grouped GEMM

A lightweight library exposing grouped GEMM kernels in PyTorch.

# Installation

Run `pip install grouped_gemm` to install the package.

# Compiling from source

By default, the installed package runs in conservative (`cuBLAS`) mode:
it launches one GEMM kernel per batch element instead of using a single
grouped GEMM kernel for the whole batch.
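The following minimal pure-Python sketch shows the difference between the two modes. It computes what a grouped GEMM computes, written the way conservative mode executes it: one matrix product per group. The data layout is an illustrative assumption, not the library's API. It assumes a single input whose consecutive row blocks, sized by a hypothetical `batch_sizes` list, are each multiplied by their own weight matrix. The helper names `matmul` and `grouped_gemm_reference` are likewise hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (m x k) @ (k x n) product; stands in for one GEMM launch."""
    k = len(b)
    n = len(b[0]) if b else 0
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def grouped_gemm_reference(a: Matrix, b: Sequence[Matrix], batch_sizes: Sequence[int]) -> Matrix:
    """Multiply consecutive row blocks of `a` by their own matrix in `b`.

    Hypothetical layout (an assumption for illustration): rows
    [start, start + batch_sizes[i]) of `a` are multiplied by b[i].
    Each loop iteration is one separate GEMM, mirroring conservative
    (cuBLAS) mode; a grouped kernel covers all groups in one launch.
    """
    if len(b) != len(batch_sizes):
        raise ValueError("need one weight matrix per group")
    if sum(batch_sizes) != len(a):
        raise ValueError("batch_sizes must sum to the number of rows in a")
    out: Matrix = []
    start = 0
    for rows, weight in zip(batch_sizes, b):
        # One GEMM per group: this is the per-element loop that the
        # grouped kernel replaces with a single launch.
        out.extend(matmul(a[start:start + rows], weight))
        start += rows
    return out


if __name__ == "__main__":
    a = [[1, 2], [3, 4], [5, 6]]
    b = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]  # identity, then 2 * identity
    print(grouped_gemm_reference(a, b, [1, 2]))
```

Running it prints `[[1, 2], [6, 8], [10, 12]]`: the first row goes through the identity and the remaining two rows are doubled. In conservative mode the number of kernel launches grows with the number of groups. A grouped kernel issues a single launch regardless of group count.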

To use grouped GEMM kernels instead, build the library in `CUTLASS` mode
by setting the `GROUPED_GEMM_CUTLASS` environment variable to `1` at build
time. For example, to build the library in `CUTLASS` mode for Ampere
(SM 8.0), clone the repository and run the following from its root:

```bash
$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
```

See [this comment](https://github.com/tgale96/grouped_gemm/pull/14#issuecomment-2211362572)
for some performance measurements on A100 and H100.

# Upcoming features

* Running grouped GEMM kernels without GPU<->CPU synchronization points.
* Hopper-optimized grouped GEMM kernels.
