Coverage for tests/test_classify.py: 94%
18 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 15:57 -0500
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 15:57 -0500
1import pytest
2import mall
3import polars as pl
4import pyarrow
5import shutil
6import os
# Remove any cache directory left over from a previous test session so the
# tests below do not read stale cached results.
#
# The check must use os.path.exists: os._exists is a private helper that only
# tests whether a name is defined inside the os module, so it never reports
# that a directory is present and the cleanup would silently never run.
if os.path.exists("_test_cache"):
    shutil.rmtree("_test_cache", ignore_errors=True)
def test_classify():
    """Check `llm.classify` on a string column with a list of labels.

    Builds a three-row frame, configures the `llm` namespace with the
    ("test", "echo") arguments and the "_test_cache" cache path, classifies
    column "x" against the labels "one" and "two", and compares the rendered
    "classify" column with the expected text.
    """
    frame = pl.DataFrame({"x": ["one", "two", "three"]})
    frame.llm.use("test", "echo", _cache="_test_cache")

    classified = frame.llm.classify("x", ["one", "two"])
    rendered = classified.select("classify").to_pandas().to_string()

    # NOTE(review): expected text is kept exactly as it appears in this
    # listing; column padding may have been collapsed when the listing was
    # captured -- confirm against the real file.
    expected = ' classify\n0 one\n1 two\n2 None'
    assert rendered == expected
def test_classify_dict():
    """Check `llm.classify` on an integer column with a label-to-value dict.

    Builds a three-row integer frame, configures the `llm` namespace with the
    ("test", "echo") arguments and the "_test_cache" cache path, classifies
    column "x" using the mapping {"one": 1, "two": 2}, and compares the
    rendered "classify" column with the expected text.
    """
    frame = pl.DataFrame({"x": [1, 2, 3]})
    frame.llm.use("test", "echo", _cache="_test_cache")

    label_values = {"one": 1, "two": 2}
    classified = frame.llm.classify("x", label_values)
    rendered = classified.select("classify").to_pandas().to_string()

    # NOTE(review): expected text is kept exactly as it appears in this
    # listing; column padding may have been collapsed when the listing was
    # captured -- confirm against the real file.
    expected = ' classify\n0 1.0\n1 2.0\n2 NaN'
    assert rendered == expected