Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

# -*- coding: utf-8 -*- 

# 

# Licensed under the Apache License, Version 2.0 (the "License"); 

# you may not use this file except in compliance with the License. 

# You may obtain a copy of the License at 

# 

# http://www.apache.org/licenses/LICENSE-2.0 

# 

# Unless required by applicable law or agreed to in writing, software 

# distributed under the License is distributed on an "AS IS" BASIS, 

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

# See the License for the specific language governing permissions and 

# limitations under the License. 

 

from builtins import str 

from datetime import datetime 

import logging 

 

from airflow.models import BaseOperator, TaskInstance 

from airflow.utils.state import State 

from airflow.utils.decorators import apply_defaults 

from airflow import settings 

 

 

class PythonOperator(BaseOperator):
    """
    Executes a Python callable

    :param python_callable: A reference to an object that is callable
    :type python_callable: python callable
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function
    :type op_kwargs: dict
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable
    :type op_args: list
    :param provide_context: if set to true, Airflow will pass a set of
        keyword arguments that can be used in your function. This set of
        kwargs correspond exactly to what you can use in your jinja
        templates. For this to work, you need to define `**kwargs` in your
        function header.
    :type provide_context: bool
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute`` takes place and are made available
        in your callable's context after the template has been applied
    :type templates_dict: dict of str
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    """
    template_fields = ('templates_dict',)
    template_ext = tuple()
    ui_color = '#ffefeb'

    @apply_defaults
    def __init__(
            self,
            python_callable,
            op_args=None,
            op_kwargs=None,
            provide_context=False,
            templates_dict=None,
            templates_exts=None,
            *args, **kwargs):
        super(PythonOperator, self).__init__(*args, **kwargs)
        self.python_callable = python_callable
        # Fall back to fresh containers so instances never share a default.
        self.op_args = op_args or []
        self.op_kwargs = op_kwargs or {}
        self.provide_context = provide_context
        self.templates_dict = templates_dict
        # Only shadow the class-level ``template_ext`` when extensions are
        # explicitly supplied.
        if templates_exts:
            self.template_ext = templates_exts

    def execute(self, context):
        """
        Call ``python_callable`` and return whatever it returns.

        When ``provide_context`` is true, ``op_kwargs`` are merged into
        ``context`` in place (``op_kwargs`` win on key clashes), the
        ``templates_dict`` attribute is stored under the ``'templates_dict'``
        key, and the resulting dict is passed to the callable as keyword
        arguments. Otherwise only ``op_kwargs`` are passed.

        :param context: the context dictionary handed to the operator
        :type context: dict
        :return: the value returned by ``python_callable``
        """
        call_kwargs = self.op_kwargs
        if self.provide_context:
            context.update(self.op_kwargs)
            context['templates_dict'] = self.templates_dict
            # Keep the merged kwargs local instead of rebinding
            # ``self.op_kwargs``: storing the context on the instance would
            # make a later ``execute`` call overwrite its fresh context with
            # stale values from this run.
            call_kwargs = context

        return_value = self.python_callable(*self.op_args, **call_kwargs)
        logging.info("Done. Returned value was: " + str(return_value))
        return return_value

 

 

class BranchPythonOperator(PythonOperator):
    """
    Allows a workflow to "branch" or follow a single path following the
    execution of this task.

    It derives the PythonOperator and expects a Python function that returns
    the task_id to follow. The task_id returned should point to a task
    directly downstream from this operator. All other "branches" or
    directly downstream tasks are marked with a state of ``skipped`` so that
    these paths can't move forward. The ``skipped`` states are propagated
    downstream to allow for the DAG state to fill up and the DAG run's state
    to be inferred.

    Note that using tasks with ``depends_on_past=True`` downstream from
    ``BranchPythonOperator`` is logically unsound as ``skipped`` status
    will invariably lead to block tasks that depend on their past successes.
    ``skipped`` states propagate where all directly upstream tasks are
    ``skipped``.
    """
    def execute(self, context):
        """
        Run the callable and mark every non-chosen direct downstream task
        as skipped.

        :param context: the context dictionary; ``context['task']`` and
            ``context['ti']`` are read to find the downstream tasks and the
            execution date
        :type context: dict
        """
        branch = super(BranchPythonOperator, self).execute(context)
        logging.info("Following branch " + branch)
        logging.info("Marking other directly downstream tasks as skipped")
        session = settings.Session()
        try:
            for task in context['task'].downstream_list:
                if task.task_id != branch:
                    ti = TaskInstance(
                        task, execution_date=context['ti'].execution_date)
                    ti.state = State.SKIPPED
                    ti.start_date = datetime.now()
                    ti.end_date = datetime.now()
                    session.merge(ti)
            session.commit()
        finally:
            # Always release the session, even if building, merging or
            # committing a task instance raised.
            session.close()
        logging.info("Done.")

 

 

class ShortCircuitOperator(PythonOperator):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks are marked with a state of "skipped". If the condition is
    True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """
    def execute(self, context):
        """
        Evaluate the callable and, when its result is falsy, mark every
        direct downstream task as skipped.

        :param context: the context dictionary; ``context['task']`` and
            ``context['ti']`` are read to find the downstream tasks and the
            execution date
        :type context: dict
        """
        condition = super(ShortCircuitOperator, self).execute(context)
        logging.info("Condition result is {}".format(condition))
        if condition:
            logging.info('Proceeding with downstream tasks...')
            return

        logging.info('Skipping downstream tasks...')
        session = settings.Session()
        try:
            for task in context['task'].downstream_list:
                ti = TaskInstance(
                    task, execution_date=context['ti'].execution_date)
                ti.state = State.SKIPPED
                ti.start_date = datetime.now()
                ti.end_date = datetime.now()
                session.merge(ti)
            session.commit()
        finally:
            # Always release the session, even if building, merging or
            # committing a task instance raised.
            session.close()
        logging.info("Done.")