Metadata-Version: 2.3
Name: kithara
Version: 0.0.5
Summary: LLM post-training library
Author: Kithara Authors
Requires-Python: >= 3.11
Description-Content-Type: text/markdown
Requires-Dist: flax>=0.7.0
Requires-Dist: datasets
Requires-Dist: huggingface-hub
Requires-Dist: keras>=3.8.0
Requires-Dist: transformers>=4.45.1
Requires-Dist: keras-hub>=0.18.1
Requires-Dist: google-api-python-client
Requires-Dist: google-auth-httplib2
Requires-Dist: google-auth-oauthlib
Requires-Dist: ray[default]==2.40.0
Requires-Dist: jax[cpu]
Requires-Dist: peft
Requires-Dist: hf_transfer
Requires-Dist: tabulate
Requires-Dist: aqtp
Requires-Dist: grain-nightly
Requires-Dist: orbax-checkpoint>=0.10.3
Requires-Dist: google-cloud-logging
Requires-Dist: tensorboardx
Requires-Dist: ml-collections
Requires-Dist: tensorflow_datasets
Requires-Dist: sentencepiece
Requires-Dist: tiktoken
Requires-Dist: cloud-accelerator-diagnostics
Requires-Dist: cloud-tpu-diagnostics
Requires-Dist: ml-goodput-measurement
Requires-Dist: google-cloud-monitoring
Requires-Dist: jax[cpu] ; extra == "cpu"
Requires-Dist: torch==2.4.0 ; extra == "cpu"
Requires-Dist: twine ; extra == "dev"
Requires-Dist: flit ; extra == "dev"
Requires-Dist: sphinx==7.1.2 ; extra == "dev"
Requires-Dist: sphinx-autobuild ; extra == "dev"
Requires-Dist: sphinx-rtd-theme ; extra == "dev"
Requires-Dist: jax[cuda] ; extra == "gpu"
Requires-Dist: torch==2.4.0 ; extra == "gpu"
Requires-Dist: jax[tpu] ; extra == "tpu"
Requires-Dist: torch==2.4.0+cpu ; extra == "tpu"
Project-URL: Documentation, https://kithara.readthedocs.io/en/latest/index.html
Project-URL: Homepage, https://github.com/wenxindongwork/keras-tuner-alpha
Project-URL: Repository, https://github.com/wenxindongwork/keras-tuner-alpha
Provides-Extra: cpu
Provides-Extra: dev
Provides-Extra: gpu
Provides-Extra: tpu

# Kithara

An LLM post-training library for TPUs and GPUs.

# Setup

Kithara requires `Python>=3.11`.

### On CPU 

``` 
pip install "kithara[cpu]"
```

### On TPU 

``` 
pip install "kithara[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu
```

### On GPU 

``` 
pip install "kithara[gpu]"
```
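
After installing, a quick sanity check can confirm the interpreter version and which key packages are present. The sketch below uses only the standard library (`sys`, `importlib.metadata`); the helper names `meets_python_requirement` and `installed_versions` are hypothetical and not part of Kithara. Missing packages are reported as `None` instead of raising.

```python
import sys
from importlib import metadata


def meets_python_requirement(version_info=sys.version_info, minimum=(3, 11)):
    """Return True if the interpreter satisfies Kithara's Python>=3.11 requirement."""
    return tuple(version_info[:2]) >= minimum


def installed_versions(packages=("kithara", "jax", "keras", "keras-hub")):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("Python OK:", meets_python_requirement())
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

On an accelerator host, `python -c "import jax; print(jax.devices())"` should additionally list the TPU or GPU devices that JAX can see.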

# Examples

## SFT with LoRA 

An example of LoRA fine-tuning gemma2-2b. This script runs in both single-host and multi-host environments, on both TPUs and GPUs. For multi-host setup, see the Ray guide in the next section.

```
python kithara/examples/singlehost/sft_lora_example.py
```
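
LoRA freezes a pretrained weight matrix W of shape d×k and trains a low-rank update BA instead, where B is d×r and A is r×k. Each adapted matrix therefore adds r(d+k) trainable parameters rather than the d·k needed to fine-tune W directly. The arithmetic below is a generic illustration of that saving, not Kithara's implementation; the dimensions and rank are hypothetical and are not the values used by the example script.

```python
def lora_trainable_params(d: int, k: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d x k weight: B (d x rank) plus A (rank x k)."""
    return rank * (d + k)


def full_trainable_params(d: int, k: int) -> int:
    """Trainable parameters when the whole d x k weight is fine-tuned."""
    return d * k


if __name__ == "__main__":
    # Hypothetical projection size and LoRA rank, for illustration only.
    d, k, rank = 2048, 2048, 16
    lora = lora_trainable_params(d, k, rank)
    full = full_trainable_params(d, k)
    print(f"LoRA: {lora:,} params, full: {full:,} params, ratio: {lora / full:.2%}")
```

For these values LoRA trains 65,536 parameters for the matrix versus 4,194,304 for full fine-tuning, about 1.6%.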

## Full parameter finetuning

An example of full-parameter fine-tuning of a MaxText model.

```
python kithara/examples/singlehost/full_finetuning_example.py
```

## Multi-host examples

Follow the instructions in `ray/README.md` to set up a Ray cluster for running multi-host workloads. Here are examples of how to run tuning tasks once your cluster has been set up.

First, copy the example scripts from the `examples/multihost` folder to a new folder on your local machine; let's call it `ray_workdir`.
Then use the `kithara multihost` CLI to run a script on your Ray cluster.


```
cd ray_workdir
kithara multihost sft_lora_example.py --hf-token your_token
```

Similarly, you can run the full-parameter fine-tuning example using the following command:

```
cd ray_workdir
kithara multihost full_finetuning_example.py --hf-token your_token
```

You can stop your job early using:

```
ray job stop ray_job_id
```
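
The same operation is available from Python through Ray's Job Submission SDK (`ray.job_submission.JobSubmissionClient`). This is a minimal sketch, assuming the Ray dashboard is reachable at `http://127.0.0.1:8265` (Ray's default dashboard address; adjust it for your cluster). `make_client` and `stop_and_report` are hypothetical helpers, and `stop_and_report` accepts any object with the same `stop_job`/`get_job_status` methods, which keeps it easy to test without a cluster.

```python
def make_client(address: str = "http://127.0.0.1:8265"):
    """Create a Ray Job Submission client (import deferred so this file loads without Ray)."""
    from ray.job_submission import JobSubmissionClient

    return JobSubmissionClient(address)


def stop_and_report(client, job_id: str):
    """Ask Ray to stop a job, then return (stop_requested, current_status)."""
    # stop_job is asynchronous: it returns True if a stop was requested for a
    # running job, and the job may still report RUNNING for a short time afterward.
    stop_requested = client.stop_job(job_id)
    return stop_requested, client.get_job_status(job_id)
```

For example, `stop_and_report(make_client(), "ray_job_id")` stops the job whose ID was printed when it was submitted.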

# Troubleshooting

1. Disk OOM (out of disk space) when loading an HF model checkpoint

    First, try emptying your cache by running the following code on your Ray cluster:

    ```python
    import shutil
    shutil.rmtree("/home/ubuntu/.cache/huggingface/hub/", ignore_errors=True)
    shutil.rmtree("/home/ubuntu/.keras/models", ignore_errors=True)
    ```

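    If you are using a single VM, the path may be different. `shutil.rmtree` does not expand `~` itself, so expand it explicitly:

    ```python
    import os
    import shutil

    # Resolve "~" to the home directory; otherwise rmtree silently removes nothing.
    shutil.rmtree(os.path.expanduser("~/.cache/huggingface/hub/"), ignore_errors=True)
    shutil.rmtree(os.path.expanduser("~/.keras/models"), ignore_errors=True)
    ```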

    If emptying the cache still doesn't help, try attaching a disk to your VM and pointing the HF cache directory at it with the `HF_HOME` environment variable: `export HF_HOME=<your_new_cache_dir>`.

    You may have to copy your HF token to this new cache directory with `cp ~/.cache/huggingface/token <your_new_cache_dir>/token`.

2. Permission denied error when uploading checkpoints to GCS

    First, verify your current authentication:

    ```
    gcloud auth list
    gsutil ls gs://your_bucket
    ```

    Python code typically authenticates to GCS with Application Default Credentials (ADC), which are set separately from your `gcloud` CLI login. Make sure ADC uses the same account:

    ```
    gcloud auth application-default login
    ```

3. `jaxlib.xla_extension.XlaRuntimeError` errors

    Try uninstalling and reinstalling `jax` and `jaxlib`. The command below reinstalls the TPU build:

    ```
    pip uninstall jax jaxlib
    pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    ```
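
To see how much space the caches from item 1 occupy before deleting anything, a standard-library sketch can resolve the cache paths and report their sizes along with free disk space. It assumes the Hugging Face Hub cache lives at `$HF_HOME/hub` (defaulting to `~/.cache/huggingface/hub`) and the Keras cache at `~/.keras/models`, matching the commands above; other overrides such as `HF_HUB_CACHE` are ignored. The helper names `cache_dirs` and `dir_size_bytes` are hypothetical.

```python
import os
import shutil


def cache_dirs(environ=os.environ):
    """Return the HF Hub and Keras model cache directories, with ~ expanded."""
    hf_home = environ.get("HF_HOME", os.path.join("~", ".cache", "huggingface"))
    return {
        "hf_hub": os.path.expanduser(os.path.join(hf_home, "hub")),
        "keras_models": os.path.expanduser(os.path.join("~", ".keras", "models")),
    }


def dir_size_bytes(path: str) -> int:
    """Total size of regular files under path; 0 if the path does not exist."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            file_path = os.path.join(root, name)
            # Skip symlinks so files the HF cache links from snapshots are not counted twice.
            if not os.path.islink(file_path):
                total += os.path.getsize(file_path)
    return total


if __name__ == "__main__":
    for label, path in cache_dirs().items():
        print(f"{label}: {path} uses {dir_size_bytes(path) / 1e9:.2f} GB")
    free = shutil.disk_usage(os.path.expanduser("~")).free
    print(f"Free space in home directory: {free / 1e9:.2f} GB")
```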


