Coverage for tests/test_classify.py: 94%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 15:57 -0500

1import pytest 

2import mall 

3import polars as pl 

4import pyarrow 

5import shutil 

6import os 

7 

# Start every run from a clean slate: drop any cache folder left behind by a
# previous test session before the tests below write to "_test_cache".
# NOTE: this must be os.path.exists -- os._exists() is a private helper that
# checks whether a name is defined inside the os module, not whether a path
# exists on disk, so the cleanup would never run with it.
if os.path.exists("_test_cache"):
    shutil.rmtree("_test_cache", ignore_errors=True)

10 

11 

def test_classify():
    """Check `llm.classify` on a string column with a list of labels.

    The frame's `llm` namespace is configured with ("test", "echo") and the
    "_test_cache" cache folder, then column "x" is classified against the
    labels "one" and "two". The resulting "classify" column is rendered via
    pandas and compared to the expected text, in which the third row is None.
    """
    frame = pl.DataFrame({"x": ["one", "two", "three"]})
    frame.llm.use("test", "echo", _cache="_test_cache")
    classified = frame.llm.classify("x", ["one", "two"])
    rendered = classified.select("classify").to_pandas().to_string()
    expected = ' classify\n0 one\n1 two\n2 None'
    assert rendered == expected

20 

def test_classify_dict():
    """Check `llm.classify` on an integer column with a dict of labels.

    The frame's `llm` namespace is configured with ("test", "echo") and the
    "_test_cache" cache folder, then column "x" is classified with the mapping
    {"one": 1, "two": 2}. The resulting "classify" column is rendered via
    pandas and compared to the expected text, in which the third row is NaN.
    """
    frame = pl.DataFrame({"x": [1, 2, 3]})
    frame.llm.use("test", "echo", _cache="_test_cache")
    label_map = {"one": 1, "two": 2}
    classified = frame.llm.classify("x", label_map)
    rendered = classified.select("classify").to_pandas().to_string()
    expected = ' classify\n0 1.0\n1 2.0\n2 NaN'
    assert rendered == expected

29