# Instructions for Training a Flow Matching Model with Sophisticated Methods and Real Data

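This file provides step-by-step instructions for training a WeatherFlow flow matching model
on real ERA5 data. It covers three approaches of increasing complexity: a simple training loop,
physics-constrained training with attention, and foundation model pre-training.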

## Prerequisites

1. Python 3.8+ installed
2. CUDA-capable GPU (recommended)
3. At least 16GB RAM
4. Internet connection (for ERA5 data access)
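
The checklist above can be partly verified with a short script before installing anything. This is a minimal sketch using only the standard library; `python_ok` and `module_installed` are hypothetical helper names. It does not check RAM (the standard library has no portable way to read it), and the `torch` lookup only reports whether PyTorch is installed, not whether a GPU is usable.

```python
import importlib.util
import sys

MIN_PYTHON = (3, 8)  # minimum version from the prerequisites list


def python_ok(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the interpreter meets the minimum (major, minor) version."""
    return tuple(version_info[:2]) >= minimum


def module_installed(name):
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    print(f"Python >= {MIN_PYTHON[0]}.{MIN_PYTHON[1]}: {python_ok()}")
    print(f"PyTorch installed: {module_installed('torch')}")
```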

## Step 1: Install Dependencies

```bash
# Install core dependencies
pip install -r requirements.txt

# Or install the package in development mode
pip install -e .
```

## Step 2: Verify Installation

```bash
python -c "from weatherflow.data import ERA5Dataset; print('Installation successful!')"
```
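
If the one-liner fails, it helps to know which submodule is at fault. The sketch below tries each module path that this guide imports later and reports failures individually. `check_imports` is a hypothetical helper, not part of WeatherFlow.

```python
import importlib

# Module paths taken from the import statements used later in this guide
MODULES = [
    "weatherflow.data",
    "weatherflow.models",
    "weatherflow.training.flow_trainer",
]


def check_imports(names):
    """Try to import each module; map its name to None (success) or an error string."""
    results = {}
    for name in names:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # ImportError, or an error raised while the module loads
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    for name, error in check_imports(MODULES).items():
        print(f"{name}: {'OK' if error is None else error}")
```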

## Step 3: Choose Your Training Approach

### Option A: Simple Flow Matching (Recommended for Beginners)

Use the `examples/flow_matching/era5_strict_training_loop.py` script:

```bash
python examples/flow_matching/era5_strict_training_loop.py \
    --data-root /path/to/era5 \
    --checkpoint-dir ./checkpoints \
    --train-years 2018 2019 \
    --val-years 2020 \
    --variables u_component_of_wind v_component_of_wind \
    --levels 850 500 \
    --batch-size 4 \
    --num-epochs 10
```
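
The space-separated values after `--train-years`, `--variables`, and `--levels` are list arguments. The script's own parser is not shown in this guide; the sketch below assumes it uses `argparse` with `nargs="+"`, a common way to accept such lists. The types and defaults are illustrative only.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the Option A flags (not the script's actual code)."""
    parser = argparse.ArgumentParser(description="Sketch of the Option A command-line flags")
    parser.add_argument("--data-root", required=True)
    parser.add_argument("--checkpoint-dir", default="./checkpoints")
    parser.add_argument("--train-years", type=int, nargs="+", required=True)
    parser.add_argument("--val-years", type=int, nargs="+", required=True)
    parser.add_argument("--variables", nargs="+", required=True)
    parser.add_argument("--levels", type=int, nargs="+", required=True)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=10)
    return parser


# Parse the same arguments as the command above, passed as an explicit list
args = build_parser().parse_args([
    "--data-root", "/path/to/era5",
    "--train-years", "2018", "2019",
    "--val-years", "2020",
    "--variables", "u_component_of_wind", "v_component_of_wind",
    "--levels", "850", "500",
])

# Training and validation years should not overlap
years_disjoint = not set(args.train_years) & set(args.val_years)
print(args.train_years, args.val_years, years_disjoint)
```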

### Option B: Advanced Flow Matching with Physics Constraints

Use the `examples/weather_prediction.py` script with attention and physics-informed training enabled:

```bash
python examples/weather_prediction.py \
    --variables z t u v \
    --pressure-levels 850 500 250 \
    --train-years 2015 2016 \
    --val-years 2017 \
    --epochs 20 \
    --use-attention \
    --physics-informed \
    --save-model \
    --save-results
```
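
The command above selects 4 variables on 3 pressure levels. If each variable-level pair becomes one input channel, as the `input_channels=12` comment in Step 4 suggests, the model sees 12 channels. The sketch below enumerates them; the `var_level` naming and the variable-major ordering are assumptions for illustration, and the data loader's actual ordering may differ.

```python
from itertools import product

variables = ["z", "t", "u", "v"]
pressure_levels = [850, 500, 250]


def channel_names(variables, levels):
    """One channel per (variable, level) pair, variable-major order (assumed)."""
    return [f"{var}_{level}" for var, level in product(variables, levels)]


names = channel_names(variables, pressure_levels)
print(len(names))  # 12
print(names[:3])   # ['z_850', 'z_500', 'z_250']
```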

### Option C: Foundation Model Pre-training (Most Sophisticated)

Use the FlowAtmosphere foundation model for massive-scale training:

```bash
# Single GPU
python foundation_model/examples/pretrain_flowatmosphere.py

# Multi-GPU (8 GPUs)
torchrun --nproc_per_node=8 \
    foundation_model/examples/pretrain_flowatmosphere.py \
    --config configs/flowatm_10b.yaml
```
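
`torchrun` starts one process per GPU and tells each process its role through the environment variables `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. How `pretrain_flowatmosphere.py` consumes them is not shown in this guide. The sketch below is one common pattern; `read_dist_env` is a hypothetical helper that falls back to single-process values when the script is launched with plain `python`.

```python
import os


def read_dist_env(env=None):
    """Read the process layout set by torchrun, defaulting to a single process."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),              # global process index
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # process index on this machine
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }


info = read_dist_env()
is_main_process = info["rank"] == 0  # typically only rank 0 logs and saves checkpoints
print(info, is_main_process)
```

With the 8-GPU command above, each process would see `WORLD_SIZE=8` and a distinct `RANK` from 0 to 7.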

## Step 4: Using the FlowTrainer API (Programmatic Approach)

For maximum control, use the FlowTrainer class directly:

```python
import torch
from weatherflow.data import create_data_loaders
from weatherflow.models import WeatherFlowMatch
from weatherflow.training.flow_trainer import FlowTrainer

# Load data
train_loader, val_loader = create_data_loaders(
    variables=['z', 't', 'u', 'v'],
    pressure_levels=[850, 500, 250],
    train_slice=('2015', '2016'),
    val_slice=('2017', '2017'),
    batch_size=16,
    num_workers=4
)

# Create model with sophisticated options
model = WeatherFlowMatch(
    input_channels=12,  # 4 variables × 3 levels
    hidden_dim=256,
    n_layers=6,
    use_attention=True,        # Use attention mechanism
    physics_informed=True,     # Apply physics constraints
    grid_size=(32, 64)        # Lat/lon grid size
)

# Set up trainer with advanced features
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
trainer = FlowTrainer(
    model=model,
    optimizer=optimizer,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    use_amp=True,                    # Mixed precision training
    use_wandb=True,                  # Weights & Biases logging
    checkpoint_dir='./checkpoints',
    physics_regularization=True,     # Physics loss regularization
    physics_lambda=0.05,             # Physics loss weight
    loss_type='huber'                # Robust loss function
)

# Training loop
for epoch in range(20):
    metrics = trainer.train_epoch(train_loader)
    val_metrics = trainer.validate(val_loader)
    
    print(f"Epoch {epoch+1}: Train Loss = {metrics['train_loss']:.4f}, "
          f"Val Loss = {val_metrics['val_loss']:.4f}")
    
    if val_metrics['val_loss'] < trainer.best_val_loss:
        trainer.best_val_loss = val_metrics['val_loss']
        trainer.save_checkpoint('best_model.pt')
```
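
For readers new to flow matching, it may help to see the objective that a trainer of this kind optimizes. `FlowTrainer`'s internals are not shown in this guide, so the following is a minimal sketch of one common formulation (straight-line interpolation between two paired states, with the velocity `x1 - x0` as the regression target), not WeatherFlow's actual implementation. It uses plain Python lists and mean squared error for brevity, whereas the trainer above is configured with `loss_type='huber'`.

```python
import random


def interpolate(x0, x1, t):
    """Point on the straight line from x0 (t=0) to x1 (t=1)."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def target_velocity(x0, x1):
    """Velocity of the straight-line path; it does not depend on t."""
    return [b - a for a, b in zip(x0, x1)]


def flow_matching_loss(predict_velocity, x0, x1, t):
    """Mean squared error between predicted and target velocity at time t."""
    x_t = interpolate(x0, x1, t)
    v_pred = predict_velocity(x_t, t)
    v_true = target_velocity(x0, x1)
    return sum((p - q) ** 2 for p, q in zip(v_pred, v_true)) / len(v_true)


# Toy example: a "model" that always predicts zero velocity
x0, x1 = [0.0, 0.0], [1.0, 2.0]
t = random.random()  # t is sampled uniformly from [0, 1)
loss = flow_matching_loss(lambda x, t: [0.0, 0.0], x0, x1, t)
print(loss)  # 2.5, independent of t for this constant predictor
```

A model that predicts `x1 - x0` exactly would drive this loss to zero; training pushes the network's velocity field toward that target across many sampled pairs and times.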

## Step 5: Troubleshooting Common Issues

### Issue 1: Module Not Found Errors
**Problem**: `ModuleNotFoundError: No module named 'weatherflow'`

**Solution**: Install the package with `pip install -e .` from the repository root

### Issue 2: ERA5 Data Access Errors
**Problem**: Cannot access ERA5 data from WeatherBench2

**Solution**:
- Check your internet connection
- Install the cloud storage dependencies: `pip install gcsfs zarr`
- Or download ERA5 data locally and point to it with `--data-root`

### Issue 3: CUDA Out of Memory
**Problem**: `RuntimeError: CUDA out of memory`

**Solution**:
- Reduce batch size: `--batch-size 2`
- Reduce model size: use fewer layers (`n_layers`) or a smaller `hidden_dim`
- Enable gradient checkpointing in the model configuration, if your model exposes that option
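
A rough size estimate shows why the first two suggestions help. The sketch below computes the memory of a single activation tensor, assuming each layer produces one tensor of shape `[batch, hidden_dim, lat, lon]`. That shape is an assumption for illustration; real usage also includes weights, gradients, optimizer state, and attention maps, so treat the numbers as a lower bound per layer.

```python
from math import prod


def tensor_bytes(shape, bytes_per_element=4):
    """Memory of a dense tensor; 4 bytes per element for float32, 2 for float16."""
    return prod(shape) * bytes_per_element


def to_mib(num_bytes):
    """Convert bytes to mebibytes."""
    return num_bytes / 2**20


# Settings from Step 4: hidden_dim=256 on a 32 x 64 grid
print(to_mib(tensor_bytes((16, 256, 32, 64))))     # 32.0 MiB at batch size 16
print(to_mib(tensor_bytes((2, 256, 32, 64))))      # 4.0 MiB at batch size 2
print(to_mib(tensor_bytes((16, 128, 32, 64))))     # 16.0 MiB with hidden_dim halved
print(to_mib(tensor_bytes((16, 256, 32, 64), 2)))  # 16.0 MiB in float16
```

Activation memory scales linearly with both batch size and `hidden_dim`, so halving either roughly halves this part of the footprint.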

### Issue 4: Import Errors in Examples
**Problem**: `ImportError` when running examples

**Solution**: Make sure you're running from the repository root, or use the fixed versions:
- Use `examples/flow_matching/simple_example_fixed.py` instead of `simple_example.py`

### Issue 5: Disk Space Issues
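**Problem**: `OSError: [Errno 28] No space left on device`

**Solution**:
- Check what is using space with `df -h` and `du -sh <dir>`, then delete files you no longer need, such as old checkpoints. Avoid `rm -rf /tmp/*`, which can remove files that running programs still depend on
- Use a different checkpoint directory with more space
- Reduce data caching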
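
To catch this error before a long run rather than in the middle of it, free space can be checked up front. This is a minimal sketch using the standard library; `free_gib` and `has_space` are hypothetical helpers, and the 5 GiB threshold is an arbitrary example value, not a measured WeatherFlow requirement.

```python
import shutil


def free_gib(path="."):
    """Free space, in GiB, on the filesystem that contains `path`."""
    return shutil.disk_usage(path).free / 2**30


def has_space(path, required_gib):
    """Return True if at least `required_gib` GiB are free at `path`."""
    return free_gib(path) >= required_gib


if __name__ == "__main__":
    target = "."  # replace with your checkpoint directory
    print(f"Free space: {free_gib(target):.1f} GiB")
    if not has_space(target, required_gib=5.0):
        print("Warning: low disk space; checkpoint saves may fail")
```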

## Step 6: Validation and Inference

After training, generate predictions:

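```python
import torch
from weatherflow.models import WeatherFlowODE

# Reuses `model` and `val_loader` from Step 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load trained model weights
# (assumes the file holds a plain state_dict; adjust if save_checkpoint wraps it in a dict)
state = torch.load('checkpoints/best_model.pt', map_location=device)
model.load_state_dict(state)
model.to(device)
model.eval()

# Create ODE solver
ode_solver = WeatherFlowODE(
    flow_model=model,
    solver_method='dopri5',
    rtol=1e-4,
    atol=1e-4
)

# Generate predictions
with torch.no_grad():
    initial_state = next(iter(val_loader))['input'].to(device)
    times = torch.linspace(0, 1, 10, device=device)
    predictions = ode_solver(initial_state, times)

print(f"Predictions shape: {predictions.shape}")  # [time, batch, channels, lat, lon]
```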
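
Conceptually, the ODE solver starts from the initial state and follows the learned velocity field from `t=0` to `t=1`, recording the state at each requested time. The sketch below shows that idea with fixed-step explicit Euler integration on plain Python lists. Euler is a deliberately simpler stand-in for the adaptive `dopri5` method used above, and `euler_integrate` is a hypothetical helper, not part of WeatherFlow.

```python
import math


def euler_integrate(velocity, x0, times, substeps=100):
    """Integrate dx/dt = velocity(x, t); return the state at each requested time."""
    x = list(x0)
    trajectory = [list(x)]
    for t_start, t_end in zip(times[:-1], times[1:]):
        dt = (t_end - t_start) / substeps
        t = t_start
        for _ in range(substeps):
            v = velocity(x, t)
            x = [xi + dt * vi for xi, vi in zip(x, v)]
            t += dt
        trajectory.append(list(x))
    return trajectory


# Toy velocity field dx/dt = -x, whose exact solution is x(t) = x0 * exp(-t)
times = [i / 9 for i in range(10)]  # same spacing as torch.linspace(0, 1, 10)
trajectory = euler_integrate(lambda x, t: [-xi for xi in x], [1.0], times)

print(len(trajectory))                          # 10 states, one per requested time
print(abs(trajectory[-1][0] - math.exp(-1.0)))  # small discretization error
```

Adaptive solvers such as `dopri5` choose the step size automatically to meet the `rtol` and `atol` tolerances instead of using a fixed number of substeps.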

## Additional Resources

- Main README: README.md
- Advanced Usage Guide: docs/advanced_usage.md
- Troubleshooting Guide: docs/troubleshooting.md
- API Documentation: docs/api/training.md
- ERA5 Data Tutorial: docs/tutorials/era5.md
- Foundation Model: foundation_model/README.md

## Getting Help

If you encounter issues:
1. Check the troubleshooting guide: [docs/troubleshooting.md](docs/troubleshooting.md)
2. Review the examples in the `examples/` directory
3. Read the documentation in the `docs/` directory
4. File an issue on GitHub: https://github.com/monksealseal/weatherflow/issues
