Coverage for mall/llm.py: 96%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 16:00 -0500

1import polars as pl 

2import ollama 

3import json 

4import hashlib 

5import os 

6 

7 

def map_call(df, col, msg, pred_name, use, valid_resps="", convert=None):
    """Run the LLM over every value of a column and append the predictions.

    Args:
        df: Polars frame to process; must support ``with_columns``.
        col: Name of the column whose values are sent to the model.
        msg: Message templates passed through to ``llm_call``.
        pred_name: Name given to the new column holding the responses.
        use: Backend settings dict, passed through to ``llm_call``.
        valid_resps: Accepted responses as a list or dict (dict values are
            used). The default ``""`` means "no restriction".
        convert: Optional dict passed through to ``llm_call`` for mapping raw
            responses to replacement values.

    Returns:
        The frame returned by ``df.with_columns`` with the extra column
        ``pred_name``. The column is ``pl.Int8`` when every valid response is
        an ``int``, otherwise ``pl.String``.
    """
    if valid_resps == "":
        valid_resps = []
    valid_resps = valid_output(valid_resps)

    # Number of valid responses that are integers.
    ints = sum(isinstance(resp, int) for resp in valid_resps)

    # Integer output is used only when there is at least one valid response
    # and all of them are ints. Written with a logical `and`; the previous
    # bitwise `&` form only worked through operator precedence and
    # comparison chaining.
    if ints != 0 and len(valid_resps) == ints:
        pl_type = pl.Int8
        data_type = int
    else:
        pl_type = pl.String
        data_type = str

    df = df.with_columns(
        pl.col(col)
        .map_elements(
            lambda x: llm_call(
                x=x,
                msg=msg,
                use=use,
                preview=False,
                valid_resps=valid_resps,
                convert=convert,
                data_type=data_type,
            ),
            return_dtype=pl_type,
        )
        .alias(pred_name)
    )
    return df

40 

41 

def llm_call(x, msg, use, preview=False, valid_resps="", convert=None, data_type=None):
    """Send one value to the configured backend and post-process the reply.

    Args:
        x: Value substituted into the message templates.
        msg: List of dicts with "role" and "content" keys; each "content" is
            formatted with ``x`` by ``build_msg``.
        use: Settings dict. Keys read here: "backend", "model", "options"
            and "_cache" (cache folder; ``""`` or missing disables caching).
        preview: When true, print the assembled call before running it.
        valid_resps: Collection of accepted responses. When non-empty, any
            response not in it is replaced by ``None``.
        convert: Optional dict; a response equal to one of its keys is
            replaced by the corresponding value.
        data_type: When ``int``, the response is converted to an integer.

    Returns:
        The processed response, or ``None`` when the response is not an
        accepted value or cannot be converted to an integer. For the "test"
        backend with the "content" model, the first message's raw "content"
        is returned without any post-processing.

    Raises:
        ValueError: If the backend (or the test model) is not one handled
            here.
    """
    backend = use.get("backend")
    model = use.get("model")
    options = use.get("options")
    # Build the messages once and reuse them for the cache key and the call.
    messages = build_msg(x, msg)
    call = dict(
        backend=backend,
        model=model,
        messages=messages,
        options=options,
    )

    if preview:
        print(call)

    # A missing "_cache" key (None) is treated like "" so that no cache path
    # is ever built from None.
    cache_dir = use.get("_cache")
    use_cache = cache_dir is not None and cache_dir != ""

    cache = ""
    if use_cache:
        hash_call = build_hash(call)
        cache = cache_check(hash_call, use)

    if cache == "":
        if backend == "ollama":
            resp = ollama.chat(
                model=model,
                messages=messages,
                options=options,
            )
            out = resp["message"]["content"]
        elif backend == "test":
            if model == "echo":
                out = x
            elif model == "content":
                # NOTE(review): the original indentation was lost; this early
                # return is assumed to belong to the "content" model only.
                # Confirm against the repository.
                return msg[0]["content"]
            else:
                raise ValueError(f"Unsupported test model: {model!r}")
        else:
            raise ValueError(f"Unsupported backend: {backend!r}")

        if use_cache:
            cache_record(hash_call, use, call, out)
    else:
        out = cache

    if isinstance(convert, dict):
        # Kept as a sequential scan: a converted value that equals a later
        # key is converted again, exactly as before.
        for label in convert:
            if out == label:
                out = convert.get(label)

    if data_type == int:
        try:
            out = data_type(out)
        except (TypeError, ValueError):
            # A reply that is not an integer is an invalid response.
            return None

    # Length is checked first so the default "" never reaches the `in` test
    # with a non-string value.
    if len(valid_resps) > 0 and out not in valid_resps:
        out = None

    return out

95 

96 

def valid_output(x):
    """Normalise the accepted-responses argument to a list.

    A list is returned unchanged (the same object), a dict yields the list of
    its values in iteration order, and anything else yields an empty list.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, dict):
        return list(x.values())
    return []

105 

106 

def build_msg(x, msg):
    """Fill ``x`` into every message template.

    Each entry of ``msg`` must provide "role" and "content"; the result is a
    new list of dicts with the same roles and with ``content.format(x)`` as
    content.
    """
    return [
        {"role": template["role"], "content": template["content"].format(x)}
        for template in msg
    ]

112 

113 

def build_hash(x):
    """Return the SHA-1 hex digest of ``x``.

    A dict is first serialised with ``json.dumps``; any other value is
    expected to be a string and is hashed from its UTF-8 encoding.
    """
    text = json.dumps(x) if isinstance(x, dict) else x
    return hashlib.sha1(text.encode("utf-8")).hexdigest()

120 

121 

def cache_check(hash_call, use):
    """Look up a cached response for ``hash_call``.

    Args:
        hash_call: Hash identifying the call; determines the cache file name.
        use: Settings dict passed to ``cache_path`` to locate the file.

    Returns:
        The "response" value stored in the cache file, or ``""`` when no
        cache file exists.
    """
    file_path = cache_path(hash_call, use)
    if not os.path.isfile(file_path):
        return ""
    # The context manager closes the file; it was previously left open.
    with open(file_path) as cache_file:
        cached = json.load(cache_file)
    return cached.get("response")

132 

133 

def cache_record(hash_call, use, call, response):
    """Write a call and its response to the cache as a JSON file.

    Args:
        hash_call: Hash identifying the call; determines the cache file name.
        use: Settings dict passed to ``cache_path`` to locate the file.
        call: The request, stored under the "request" key.
        response: The response, stored under the "response" key.
    """
    file_path = cache_path(hash_call, use)
    # exist_ok avoids failing when the folder is created between a separate
    # existence check and the creation call.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    # Serialise before opening so a serialisation error cannot leave a
    # truncated cache file behind.
    json_contents = json.dumps(dict(request=call, response=response))
    with open(file_path, "w") as file:
        file.write(json_contents)

143 

144 

def cache_path(hash_call, use):
    """Build the cache file path for ``hash_call``.

    The layout is ``<use["_cache"]>/<first two hash characters>/<hash>.json``.
    """
    shard = hash_call[:2]
    return "/".join((use.get("_cache"), shard, hash_call + ".json"))