Metadata-Version: 2.4
Name: mlxsummary
Version: 0.1.0
Summary: Model inspection and summary tools for MLX neural networks
Keywords: mlx,apple,machine-learning,deep-learning,neural-network,model-summary,pytorch-summary
Author: Dhruv Shrivastava
Author-email: Dhruv Shrivastava <dhruvshrivastava@hotmail.com>
License-Expression: MIT
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Dist: mlx>=0.1.0
Requires-Dist: pytest>=7.0 ; extra == 'dev'
Requires-Dist: pytest-cov>=4.0 ; extra == 'dev'
Requires-Dist: black>=23.0 ; extra == 'dev'
Requires-Dist: ruff>=0.1.0 ; extra == 'dev'
Requires-Dist: mypy>=1.0 ; extra == 'dev'
Requires-Python: >=3.9
Project-URL: Documentation, https://github.com/dhruvshr/mlxsummary#readme
Project-URL: Homepage, https://github.com/dhruvshr/mlxsummary
Project-URL: Issues, https://github.com/dhruvshr/mlxsummary/issues
Project-URL: Repository, https://github.com/dhruvshr/mlxsummary
Provides-Extra: dev
Description-Content-Type: text/markdown

# mlxsummary

Model inspection and summary tools for [MLX](https://github.com/ml-explore/mlx) neural networks on Apple Silicon.

Inspired by `torchsummary` for PyTorch, but designed specifically for MLX's module system.

## Features

- 📊 **Multiple output formats**: Table, Tree, JSON, Markdown, Minimal
- 🔍 **Detailed inspection**: Layer paths, parameter counts, shapes
- 🎯 **Filtering**: Find layers by type, name pattern, or parameter count
- 📈 **Statistics**: Aggregate stats by layer type
- 🖥️ **CLI support**: Use from command line or Python
- 🧊 **Freeze-aware**: Track trainable vs frozen parameters

## Installation

```bash
pip install mlxsummary
```

Or install from source:

```bash
git clone https://github.com/dhruvshr/mlxsummary.git
cd mlxsummary
pip install -e .
```

## Quick Start

```python
import mlx.nn as nn
from mlxsummary import summary

# Create a model
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Print summary
summary(model)
```

Output:
```
=============================================================================================================================
 Model Summary: Sequential
=============================================================================================================================
Layer                                         Type                         Params Details                 Trainable
-----------------------------------------------------------------------------------------------------------------------------
layers.0                                      Linear                      200,960 (784 → 256)               200,960
layers.1                                      ReLU                              0                                 0
layers.2                                      Dropout                           0                                 0
layers.3                                      Linear                       32,896 (256 → 128)                32,896
layers.4                                      ReLU                              0                                 0
layers.5                                      Linear                        1,290 (128 → 10)                  1,290
-----------------------------------------------------------------------------------------------------------------------------
Total Parameters:                                                         235,146
Trainable Parameters:                                                     235,146
=============================================================================================================================
```
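Each `Linear(in, out)` layer stores an `out × in` weight matrix plus an `out`-element bias, so it holds `in * out + out` parameters, while `ReLU` and `Dropout` hold none. The plain-Python check below reproduces the counts in the table above from the Quick Start layer sizes (no MLX needed):

```python
# Layer sizes from the Quick Start model: (input_dims, output_dims) per Linear.
linear_dims = [(784, 256), (256, 128), (128, 10)]


def linear_params(in_dims: int, out_dims: int, bias: bool = True) -> int:
    """Weight matrix (out_dims x in_dims) plus an optional bias vector."""
    return in_dims * out_dims + (out_dims if bias else 0)


per_layer = [linear_params(i, o) for i, o in linear_dims]
total = sum(per_layer)  # ReLU and Dropout contribute 0 parameters

print(per_layer)     # [200960, 32896, 1290]
print(f"{total:,}")  # 235,146
```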

## Output Formats

### Table (default)
```python
summary(model, format="table")
```

### Tree
```python
summary(model, format="tree")
```
```
📦 Sequential (235,146 params)
│   ├── 0: Linear (784 → 256) [200,960]
│   ├── 1: ReLU [0]
│   ├── 2: Dropout [0]
│   ├── 3: Linear (256 → 128) [32,896]
│   ├── 4: ReLU [0]
│   ├── 5: Linear (128 → 10) [1,290]
```

### JSON
```python
data = summary(model, format="json", print_output=False)
```

### Markdown
```python
summary(model, format="markdown")
```

### Minimal (one-line)
```python
summary(model, format="minimal")
# Output: Sequential: 235,146 params (6 layers)
```

## Programmatic API

### Inspector

For detailed programmatic access:

```python
import mlx.nn as nn
from mlxsummary import inspect

# `model` is the Sequential defined in Quick Start
inspector = inspect(model)

# Get all layers
layers = inspector.get_layers()
for layer in layers:
    print(f"{layer.path}: {layer.total_params:,} params")

# Get statistics
stats = inspector.get_stats()
print(f"Total: {stats.total_params:,}")
print(f"Layer types: {stats.layer_type_counts}")

# Find specific layers
linear_layers = inspector.find_layers(layer_type=nn.Linear)
attention = inspector.find_layers(name_pattern="attention")
large_layers = inspector.find_layers(min_params=10000)
```
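The `find_layers` calls above apply one filter at a time. The following pure-Python sketch shows what such filtering might look like over a list of layer records. `LayerRecord` and this `find_layers` are hypothetical stand-ins, not mlxsummary's code, and the sketch makes several assumptions: layer types are compared by name string rather than MLX class, `name_pattern` is treated as a regular expression searched within the layer path, `min_params` is inclusive, and multiple filters combine with AND.

```python
import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LayerRecord:
    """Hypothetical stand-in for a per-layer info object."""
    path: str
    type_name: str
    total_params: int


def find_layers(
    layers: List[LayerRecord],
    type_name: Optional[str] = None,
    name_pattern: Optional[str] = None,
    min_params: Optional[int] = None,
) -> List[LayerRecord]:
    """Keep layers matching every filter that is set (assumed AND semantics)."""
    result = []
    for layer in layers:
        if type_name is not None and layer.type_name != type_name:
            continue
        # Assumption: name_pattern is a regex searched anywhere in the path.
        if name_pattern is not None and re.search(name_pattern, layer.path) is None:
            continue
        # Assumption: min_params is an inclusive lower bound.
        if min_params is not None and layer.total_params < min_params:
            continue
        result.append(layer)
    return result


# Records mirroring the Quick Start summary table.
layers = [
    LayerRecord("layers.0", "Linear", 200_960),
    LayerRecord("layers.1", "ReLU", 0),
    LayerRecord("layers.2", "Dropout", 0),
    LayerRecord("layers.3", "Linear", 32_896),
    LayerRecord("layers.4", "ReLU", 0),
    LayerRecord("layers.5", "Linear", 1_290),
]

big_linear = find_layers(layers, type_name="Linear", min_params=10_000)
print([layer.path for layer in big_linear])  # ['layers.0', 'layers.3']
```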

### Convenience Functions

```python
import mlx.nn as nn
from mlxsummary import count_params, get_layers, get_stats, to_dict

# Count parameters
total = count_params(model)
trainable = count_params(model, trainable_only=True)

# Get layers
layers = get_layers(model)
linear_layers = get_layers(model, layer_type=nn.Linear)

# Get stats
stats = get_stats(model)

# Export to dict
data = to_dict(model)
```
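In MLX, `Module.parameters()` returns a nested dict/list of arrays, `mlx.utils.tree_flatten` turns that tree into `(dotted_path, array)` pairs such as `("layers.0.weight", ...)`, and `Module.trainable_parameters()` omits parameters frozen with `Module.freeze()`. The sketch below mimics that flatten-and-sum idea in plain Python so it runs without MLX. It is not necessarily how `count_params` is implemented: shape tuples stand in for `mx.array` leaves, and the `frozen` set is a hypothetical substitute for MLX's freeze state.

```python
from math import prod
from typing import AbstractSet, Any, List, Tuple


def flatten(tree: Any, prefix: str = "") -> List[Tuple[str, Any]]:
    """Flatten nested dicts/lists into (dotted_path, leaf) pairs like 'layers.0.weight'."""
    if isinstance(tree, dict):
        pairs = list(tree.items())
    elif isinstance(tree, list):
        pairs = list(enumerate(tree))
    else:
        return [(prefix, tree)]  # leaf: a shape tuple in this sketch
    flat: List[Tuple[str, Any]] = []
    for key, value in pairs:
        path = f"{prefix}.{key}" if prefix else str(key)
        flat.extend(flatten(value, path))
    return flat


def count_leaf_params(
    tree: Any, frozen: AbstractSet[str] = frozenset(), trainable_only: bool = False
) -> int:
    """Sum the element counts of all leaves, optionally skipping frozen paths."""
    return sum(
        prod(shape)
        for path, shape in flatten(tree)
        if not (trainable_only and path in frozen)
    )


# Parameter tree shaped like the Quick Start model's (weights are out x in).
params = {
    "layers": [
        {"weight": (256, 784), "bias": (256,)},
        {},  # ReLU: no parameters
        {},  # Dropout: no parameters
        {"weight": (128, 256), "bias": (128,)},
        {},  # ReLU
        {"weight": (10, 128), "bias": (10,)},
    ]
}
frozen = {"layers.0.weight", "layers.0.bias"}  # hypothetical: first Linear frozen

print(count_leaf_params(params))                               # 235146
print(count_leaf_params(params, frozen, trainable_only=True))  # 34186
```

With MLX installed, the equivalent total is typically computed as `sum(v.size for _, v in tree_flatten(model.parameters()))`, swapping in `model.trainable_parameters()` for the trainable subset.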

## Options

```python
summary(
    model,
    format="table",           # Output format
    show_shapes=True,         # Show layer dimensions
    show_trainable=True,      # Show trainable params column
    show_frozen=False,        # Show frozen params column
    max_depth=None,           # Limit layer depth
    max_rows=None,            # Limit output rows
    include_zero_param=True,  # Include zero-param layers
    width=100,                # Output width
    print_output=True,        # False: return the result without printing
)
```

## Command Line

```bash
# Summarize a model from a file
mlxsummary model.py

# Different formats
mlxsummary model.py --format tree
mlxsummary model.py --format json -o model.json
mlxsummary model.py --format markdown

# Options
mlxsummary model.py --max-depth 2
mlxsummary model.py --hide-zero
mlxsummary model.py --no-shapes

# Demo mode
mlxsummary --demo
```

Your model file should define either a `model` variable or a `get_model()` function that returns the model:

```python
# model.py
import mlx.nn as nn

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
```
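For readers curious how a command-line tool can pick up such a file, the sketch below uses the standard library's `importlib.util` to execute a Python file and return its `model` (or the result of `get_model()`). It is an illustration, not the mlxsummary CLI's actual loader; `load_model` is a hypothetical name, and preferring `model` over `get_model()` when both exist is an assumption. The demo writes a tiny MLX-free file so the sketch runs anywhere.

```python
import importlib.util
import tempfile
from pathlib import Path
from typing import Any


def load_model(path: str) -> Any:
    """Execute a Python file and return its `model` variable or `get_model()` result."""
    spec = importlib.util.spec_from_file_location("user_model", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    # Assumption: a `model` variable takes precedence over `get_model()`.
    if hasattr(module, "model"):
        return module.model
    if hasattr(module, "get_model"):
        return module.get_model()
    raise AttributeError(f"{path} defines neither `model` nor `get_model()`")


# Demo with a stand-in file; a real model file would build an mlx.nn module instead.
with tempfile.TemporaryDirectory() as tmp:
    demo_file = Path(tmp) / "model.py"
    demo_file.write_text("def get_model():\n    return 'demo-model'\n")
    demo = load_model(str(demo_file))

print(demo)  # demo-model
```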

## API Reference

### Classes

| Class | Description |
|-------|-------------|
| `MLXInspector` | Main inspector for detailed model analysis |
| `LayerInfo` | Information about a single layer |
| `ModelStats` | Aggregate statistics about a model |
| `FormatterOptions` | Options for output formatting |
| `OutputFormat` | Enum of available output formats |

### Functions

| Function | Description |
|----------|-------------|
| `summary(model, ...)` | Generate and print a model summary |
| `inspect(model)` | Create an inspector instance |
| `count_params(model)` | Count model parameters |
| `get_layers(model)` | Get list of layer information |
| `get_stats(model)` | Get aggregate statistics |
| `to_dict(model)` | Export model info to dictionary |
| `tree(model)` | Shortcut for tree format |
| `table(model)` | Shortcut for table format |

## Requirements

- macOS with Apple Silicon (M1/M2/M3/M4)
- Python 3.9+
- MLX 0.1.0+

## License

MIT License - see LICENSE file for details.

## Contributing

Contributions are welcome! Please open an issue or pull request.

## See Also

- [MLX Documentation](https://ml-explore.github.io/mlx/)
- [MLX GitHub](https://github.com/ml-explore/mlx)
- [torchsummary](https://github.com/sksq96/pytorch-summary) - Similar tool for PyTorch

