Coverage for mall/llm.py: 96%
93 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 16:00 -0500
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 16:00 -0500
1import polars as pl
2import ollama
3import json
4import hashlib
5import os
def map_call(df, col, msg, pred_name, use, valid_resps="", convert=None):
    """Add a column of LLM predictions to a Polars data frame.

    ``llm_call`` is applied to every element of ``col`` and the results are
    stored in a new column named ``pred_name``.

    Args:
        df: Polars data frame to extend.
        col: Name of the column whose values are sent to the model.
        msg: Message templates forwarded to ``llm_call``.
        pred_name: Name given to the new prediction column.
        use: Backend settings dict forwarded to ``llm_call``.
        valid_resps: Allowed responses, as a list or as a dict whose values
            are the allowed responses. The default ``""`` means no
            restriction.
        convert: Optional dict forwarded to ``llm_call`` for translating
            raw responses.

    Returns:
        The data frame with the ``pred_name`` column added.
    """
    if valid_resps == "":
        valid_resps = []
    valid_resps = valid_output(valid_resps)

    # The column is integer-typed only when there is at least one valid
    # response and every one of them is an int. The previous spelling used a
    # bitwise ``&`` and only worked through comparison chaining.
    all_ints = len(valid_resps) > 0 and all(
        isinstance(resp, int) for resp in valid_resps
    )
    if all_ints:
        pl_type = pl.Int8
        data_type = int
    else:
        pl_type = pl.String
        data_type = str

    df = df.with_columns(
        pl.col(col)
        .map_elements(
            lambda x: llm_call(
                x=x,
                msg=msg,
                use=use,
                preview=False,
                valid_resps=valid_resps,
                convert=convert,
                data_type=data_type,
            ),
            return_dtype=pl_type,
        )
        .alias(pred_name)
    )
    return df
def llm_call(x, msg, use, preview=False, valid_resps="", convert=None, data_type=None):
    """Send one value to the configured backend and post-process the reply.

    Args:
        x: Value inserted into the message templates.
        msg: List of ``{"role": ..., "content": ...}`` message templates.
        use: Settings dict; the keys read here are ``backend``, ``model``,
            ``options`` and ``_cache``.
        preview: When true, print the assembled call before running it.
        valid_resps: Allowed final values. When non-empty, any other result
            is replaced with ``None``.
        convert: Optional dict mapping a raw response to a replacement value.
        data_type: When ``int``, the response is converted to an integer.

    Returns:
        The processed response, or ``None`` when it is not an allowed value
        or cannot be converted to ``int``. With the ``test`` backend and the
        ``content`` model the first message's raw ``content`` is returned
        directly, skipping caching and post-processing.

    Raises:
        ValueError: If the backend, or the model of the ``test`` backend, is
            not one this function knows how to run.
    """
    backend = use.get("backend")
    model = use.get("model")
    options = use.get("options")
    # Built once and reused for both the cache key and the actual request.
    messages = build_msg(x, msg)
    call = dict(
        backend=backend,
        model=model,
        messages=messages,
        options=options,
    )

    if preview:
        print(call)

    # Caching is disabled by an empty "_cache"; a missing key (None) is
    # treated the same way instead of failing later while building the path.
    use_cache = bool(use.get("_cache"))

    cache = ""
    hash_call = None
    if use_cache:
        hash_call = build_hash(call)
        cache = cache_check(hash_call, use)

    if cache == "":
        if backend == "ollama":
            resp = ollama.chat(
                model=model,
                messages=messages,
                options=options,
            )
            out = resp["message"]["content"]
        elif backend == "test":
            if model == "echo":
                out = x
            elif model == "content":
                # Hands back the unformatted template as-is.
                return msg[0]["content"]
            else:
                raise ValueError(f"Unsupported test model: {model!r}")
        else:
            raise ValueError(f"Unsupported backend: {backend!r}")

        # Only freshly produced responses are written to the cache.
        if use_cache:
            cache_record(hash_call, use, call, out)
    else:
        out = cache

    if isinstance(convert, dict):
        for label in convert:
            if out == label:
                out = convert.get(label)
                # Stop at the first match so a converted value is never
                # converted a second time by a later label.
                break

    if data_type is int:
        try:
            out = int(out)
        except (TypeError, ValueError):
            # A reply that is not a number is treated as an invalid response.
            out = None

    # Check for emptiness first: with the default "" a membership test on a
    # non-string value would raise TypeError.
    if len(valid_resps) > 0 and out not in valid_resps:
        out = None

    return out
def valid_output(x):
    """Normalize a valid-responses specification to a list.

    A list is returned unchanged (the same object), a dict yields the list
    of its values, and anything else yields an empty list.
    """
    if isinstance(x, dict):
        return [x.get(key) for key in x]
    if isinstance(x, list):
        return x
    return []
def build_msg(x, msg):
    """Return a copy of the message templates with ``x`` formatted in.

    Each entry of ``msg`` must have ``role`` and ``content`` keys; ``x`` is
    passed to ``str.format`` on the ``content`` text.
    """
    return [
        {"role": entry["role"], "content": entry["content"].format(x)}
        for entry in msg
    ]
def build_hash(x):
    """Return the SHA-1 hex digest of ``x``.

    A dict is serialized with ``json.dumps`` first; any other value is
    expected to be a string and is hashed from its UTF-8 encoding.
    """
    text = json.dumps(x) if isinstance(x, dict) else x
    return hashlib.sha1(text.encode("utf-8")).hexdigest()
def cache_check(hash_call, use):
    """Look up a cached response.

    Args:
        hash_call: Hash string identifying the call.
        use: Settings dict passed on to ``cache_path``.

    Returns:
        The ``response`` entry of the cached JSON file, or ``""`` when no
        cache file exists for this hash.
    """
    file_path = cache_path(hash_call, use)
    if not os.path.isfile(file_path):
        return ""
    # The context manager guarantees the handle is closed after reading.
    with open(file_path) as cache_file:
        return json.load(cache_file).get("response")
def cache_record(hash_call, use, call, response):
    """Write a call and its response to the cache as a JSON file.

    Args:
        hash_call: Hash string identifying the call.
        use: Settings dict passed on to ``cache_path``.
        call: The request, stored under the ``request`` key.
        response: The response, stored under the ``response`` key.
    """
    file_path = cache_path(hash_call, use)
    # exist_ok avoids failing when another writer creates the folder between
    # a separate existence check and the creation itself.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    json_contents = json.dumps(dict(request=call, response=response))
    with open(file_path, "w") as file:
        file.write(json_contents)
def cache_path(hash_call, use):
    """Build the cache file path for a call hash.

    The file is placed under ``use["_cache"]``, in a sub-folder named after
    the first two characters of the hash, as ``<hash>.json``.
    """
    parts = [use.get("_cache"), hash_call[0:2], hash_call + ".json"]
    return "/".join(parts)