Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

from builtins import zip 

from builtins import str 

import logging 

 

from airflow.utils import AirflowException 

from airflow.hooks import PrestoHook 

from airflow.models import BaseOperator 

from airflow.utils import apply_defaults 

 

 

class CheckOperator(BaseOperator):
    """
    Performs checks against a db. The ``CheckOperator`` expects
    a sql query that will return a single row. Each value on that
    first row is evaluated using python ``bool`` casting. If any of the
    values is falsy the check fails and the task errors out.

    Note that Python bool casting evals the following as ``False``:

    * ``False``
    * ``None``
    * ``0``
    * Empty string (``""``)
    * Empty list (``[]``)
    * Empty dictionary or set (``{}``)

    Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
    the count ``== 0``. You can craft much more complex queries that could,
    for instance, check that the table has the same number of rows as
    the source table upstream, or that the count of today's partition is
    greater than yesterday's partition, or that a set of metrics are less
    than 3 standard deviations from the 7 day average.

    This operator can be used as a data quality check in your pipeline, and
    depending on where you put it in your DAG, you have the choice to
    put it on the critical path, preventing dubious data from being
    published, or on the side, receiving email alerts without stopping
    the progress of the DAG.

    Note that this is an abstract class: subclasses must override
    ``get_db_hook`` so that it returns a hook object exposing a
    ``get_first`` method, which is called with the sql and whose result
    is treated as a single row of values.

    :param sql: the sql to be executed
    :type sql: string
    """

    template_fields = ('sql',)
    template_ext = ('.hql', '.sql',)
    ui_color = '#fff7e6'

    @apply_defaults
    def __init__(
            self, sql,
            *args, **kwargs):
        super(CheckOperator, self).__init__(*args, **kwargs)

        self.sql = sql

    def execute(self, context=None):
        """
        Run ``self.sql`` through the hook and fail on any falsy value.

        :param context: unused by this method
        :raises AirflowException: when the hook returns nothing (``None``
            or an empty row), or when at least one value of the returned
            row is falsy
        """
        logging.info('Executing SQL check: %s', self.sql)
        records = self.get_db_hook().get_first(self.sql)
        logging.info("Record: %s", str(records))
        if not records:
            raise AirflowException("The query returned None")
        elif not all(bool(r) for r in records):
            exceptstr = "Test failed.\nQuery:\n{q}\nResults:\n{r!s}"
            raise AirflowException(exceptstr.format(q=self.sql, r=records))
        logging.info("Success.")

    def get_db_hook(self):
        """
        Requires that the hook has a ``get_first`` method receiving sql
        and returning a tuple.

        :raises NotImplementedError: always; subclasses must override
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception,
        # and calling it raises TypeError; NotImplementedError is the
        # proper signal for a missing override.
        raise NotImplementedError()

 

 

def _convert_to_float_if_possible(s): 

    ''' 

    A small helper function to convert a string to a numeric value 

    if appropriate 

 

    :param s: the string to be converted 

    :type s: str 

    ''' 

    try: 

        ret = float(s) 

    except (ValueError, TypeError): 

        ret = s 

    return ret 

 

 

class ValueCheckOperator(BaseOperator):
    """
    Performs a simple value check using sql code.

    Every value of the first row returned by the query is compared with
    ``pass_value``. When ``pass_value`` can be converted to a float the
    comparison is numeric (optionally within a relative ``tolerance``);
    otherwise the string form of each value must equal ``pass_value``.

    Note that this is an abstract class: subclasses must override
    ``get_db_hook`` so that it returns a hook object exposing a
    ``get_first`` method, which is called with the sql and whose result
    is treated as a single row of values.

    :param sql: the sql to be executed
    :type sql: string
    :param pass_value: the value every element of the row must match
    :param tolerance: optional relative tolerance, as a fraction of
        ``pass_value`` (e.g. ``0.1`` accepts results within 10%); it is
        ignored when it cannot be converted to a float
    """

    __mapper_args__ = {
        'polymorphic_identity': 'ValueCheckOperator'
    }
    template_fields = ('sql',)
    template_ext = ('.hql', '.sql',)
    ui_color = '#fff7e6'

    @apply_defaults
    def __init__(
            self, sql, pass_value, tolerance=None,
            *args, **kwargs):
        super(ValueCheckOperator, self).__init__(*args, **kwargs)
        self.sql = sql
        self.pass_value = _convert_to_float_if_possible(pass_value)
        tol = _convert_to_float_if_possible(tolerance)
        # Anything that is not a float after conversion (None, junk
        # strings, ...) means "no tolerance".
        self.tol = tol if isinstance(tol, float) else None
        self.is_numeric_value_check = isinstance(self.pass_value, float)
        self.has_tolerance = self.tol is not None

    def execute(self, context=None):
        """
        Run ``self.sql`` and compare each returned value to ``pass_value``.

        :param context: unused by this method
        :raises AirflowException: when the hook returns nothing, when a
            numeric check is requested but a result cannot be converted
            to float, or when any value fails the comparison
        """
        logging.info('Executing SQL check: %s', self.sql)
        # Pass the query positionally (as CheckOperator does) so the call
        # does not depend on the hook naming its parameter ``hql``.
        records = self.get_db_hook().get_first(self.sql)
        if not records:
            raise AirflowException("The query returned None")
        except_temp = ("Test failed.\nPass value:{self.pass_value}\n"
                       "Query:\n{self.sql}\nResults:\n{records!s}")
        if not self.is_numeric_value_check:
            tests = [str(r) == self.pass_value for r in records]
        else:
            try:
                num_rec = [float(r) for r in records]
            except (ValueError, TypeError):
                cvestr = "Converting a result to float failed.\n"
                raise AirflowException(
                    cvestr + except_temp.format(self=self, records=records))
            if self.has_tolerance:
                # Multiplicative bounds are equivalent to the former
                # r / (1 + tol) <= pass_value <= r / (1 - tol) for
                # 0 <= tol < 1 and a non-negative pass_value, but cannot
                # divide by zero when tol == 1. Ordering the bounds keeps
                # the interval valid for a negative pass_value.
                bound_a = self.pass_value * (1 - self.tol)
                bound_b = self.pass_value * (1 + self.tol)
                lower = min(bound_a, bound_b)
                upper = max(bound_a, bound_b)
                tests = [lower <= r <= upper for r in num_rec]
            else:
                tests = [r == self.pass_value for r in num_rec]
        if not all(tests):
            raise AirflowException(
                except_temp.format(self=self, records=records))

    def get_db_hook(self):
        """
        Return the hook used to run the query; must be overridden.

        :raises NotImplementedError: always; subclasses must override
        """
        # ``NotImplemented`` is not callable; NotImplementedError is the
        # proper signal for a missing override.
        raise NotImplementedError()

 

 

class IntervalCheckOperator(BaseOperator):
    """
    Checks that the values of metrics given as SQL expressions are within
    a certain tolerance of the ones from days_back before.

    For every metric the ratio ``max(current, reference) /
    min(current, reference)`` is computed and must be strictly below the
    threshold configured for that metric. When either value equals zero
    no ratio can be computed and the metric is treated as passing.

    Note that this is an abstract class: subclasses must override
    ``get_db_hook`` so that it returns a hook object exposing a
    ``get_first`` method, which is called with the sql and whose result
    is treated as a single row of values.

    :param table: the table name
    :type table: str
    :param metrics_thresholds: a dictionary of ratios indexed by metrics
        (the keys are the SQL expressions that get selected)
    :type metrics_thresholds: dict
    :param date_filter_column: name of the column compared against the
        date in the ``WHERE`` clause. Defaults to ``ds``
    :type date_filter_column: str
    :param days_back: number of days between ds and the ds we want to check
        against. Defaults to 7 days; the sign is ignored
    :type days_back: int
    """

    __mapper_args__ = {
        'polymorphic_identity': 'IntervalCheckOperator'
    }
    template_fields = ('sql1', 'sql2')
    template_ext = ('.hql', '.sql',)
    ui_color = '#fff7e6'

    @apply_defaults
    def __init__(
            self, table, metrics_thresholds,
            date_filter_column='ds', days_back=-7,
            *args, **kwargs):
        super(IntervalCheckOperator, self).__init__(*args, **kwargs)
        self.table = table
        self.metrics_thresholds = metrics_thresholds
        # Sorted so the SELECT column order is deterministic and can be
        # zipped back onto the returned rows in execute().
        self.metrics_sorted = sorted(metrics_thresholds.keys())
        self.date_filter_column = date_filter_column
        # Always look into the past, whatever sign the caller used.
        self.days_back = -abs(days_back)
        sqlexp = ', '.join(self.metrics_sorted)
        sqlt = ("SELECT {sqlexp} FROM {table}"
                " WHERE {date_filter_column}=").format(
                    sqlexp=sqlexp, table=table,
                    date_filter_column=date_filter_column)
        # sql1 targets the current ds, sql2 the reference ds; both are
        # templated fields (see template_fields above).
        self.sql1 = sqlt + "'{{ ds }}'"
        self.sql2 = sqlt + "'{{ macros.ds_add(ds, "+str(self.days_back)+") }}'"

    def execute(self, context=None):
        """
        Run both queries and compare each metric against its threshold.

        :param context: unused by this method
        :raises AirflowException: when either query returns nothing, or
            when at least one metric ratio is not below its threshold
        """
        hook = self.get_db_hook()
        # Pass the queries positionally so the call does not depend on
        # the hook naming its parameter ``hql``.
        logging.info('Executing SQL check: %s', self.sql2)
        row2 = hook.get_first(self.sql2)
        logging.info('Executing SQL check: %s', self.sql1)
        row1 = hook.get_first(self.sql1)
        # Format the message *before* building the exception; calling
        # .format() on the exception object raised AttributeError.
        if not row2:
            raise AirflowException(
                "The query {q} returned None".format(q=self.sql2))
        if not row1:
            raise AirflowException(
                "The query {q} returned None".format(q=self.sql1))
        current = dict(zip(self.metrics_sorted, row1))
        reference = dict(zip(self.metrics_sorted, row2))
        ratios = {}
        test_results = {}
        rlog = "Ratio for {0}: {1} \n Ratio threshold : {2}"
        fstr = "'{k}' check failed. {r} is above {tr}"
        estr = "The following tests have failed:\n {0}"
        countstr = "The following {j} tests out of {n} failed:"
        for m in self.metrics_sorted:
            threshold = self.metrics_thresholds[m]
            if current[m] == 0 or reference[m] == 0:
                # No meaningful ratio when either side is zero.
                ratio = None
            else:
                ratio = float(max(current[m], reference[m])) / \
                    min(current[m], reference[m])
            logging.info(rlog.format(m, ratio, threshold))
            ratios[m] = ratio
            # ``None < threshold`` was True on Python 2 and is a
            # TypeError on Python 3; keep the "passes" outcome explicitly.
            test_results[m] = ratio is None or ratio < threshold
        if not all(test_results.values()):
            # Walk metrics_sorted so the report order is deterministic.
            failed_tests = [
                m for m in self.metrics_sorted if not test_results[m]]
            logging.warning(countstr.format(
                j=len(failed_tests), n=len(self.metrics_sorted)))
            for k in failed_tests:
                logging.warning(fstr.format(k=k, r=ratios[k],
                                tr=self.metrics_thresholds[k]))
            raise AirflowException(estr.format(", ".join(failed_tests)))
        logging.info("All tests have passed")

    def get_db_hook(self):
        """
        Return the hook used to run the queries; must be overridden.

        :raises NotImplementedError: always; subclasses must override
        """
        # ``NotImplemented`` is not callable; NotImplementedError is the
        # proper signal for a missing override.
        raise NotImplementedError()