Coverage for airflow/models.py : 25%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
|
# -*- coding: utf-8 -*- # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
Column, Integer, String, DateTime, Text, Boolean, ForeignKey, PickleType, Index, Float)
as_tuple, is_container, is_in, validate_key, pprinttable)
FERNET = Fernet(configuration.get('core', 'FERNET_KEY').encode('utf-8')) ENCRYPTION_ON = True
LongText = LONGTEXT else:
# used by DAG context_managers
''' Clears a set of task instances, but makes sure the running ones get killed. ''' job_ids = [] for ti in tis: if ti.state == State.RUNNING: if ti.job_id: ti.state = State.SHUTDOWN job_ids.append(ti.job_id) # todo: this creates an issue with the webui tests #elif ti.state != State.REMOVED: # ti.state = State.NONE # session.merge(ti) else: session.delete(ti) if job_ids: from airflow.jobs import BaseJob as BJ for job in session.query(BJ).filter(BJ.id.in_(job_ids)).all(): job.state = State.SHUTDOWN if activate_dag_runs: execution_dates = {ti.execution_date for ti in tis} dag_ids = {ti.dag_id for ti in tis} drs = session.query(DagRun).filter( DagRun.dag_id.in_(dag_ids), DagRun.execution_date.in_(execution_dates), ).all() for dr in drs: dr.state = State.RUNNING dr.start_date = datetime.now()
""" A dagbag is a collection of dags, parsed out of a folder tree and has high level configuration settings, like what database to use as a backend and what executor to use to fire off tasks. This makes it easier to run distinct environments for say production and development, tests, or for different teams or security profiles. What would have been system level settings are now dagbag level so that one system can run multiple, independent settings sets.
:param dag_folder: the folder to scan to find DAGs :type dag_folder: str :param executor: the executor to use when executing task instances in this DagBag :param include_examples: whether to include the examples that ship with airflow or not :type include_examples: bool :param sync_to_db: whether to sync the properties of the DAGs to the metadata DB while finding them, typically should be done by the scheduler job only :type sync_to_db: bool """ self, dag_folder=None, executor=DEFAULT_EXECUTOR, include_examples=configuration.getboolean('core', 'LOAD_EXAMPLES'), sync_to_db=False):
dag_folder = dag_folder or DAGS_FOLDER self.logger.info("Filling up the DagBag from {}".format(dag_folder)) self.dag_folder = dag_folder self.dags = {} self.sync_to_db = sync_to_db self.file_last_changed = {} self.executor = executor self.import_errors = {} if include_examples: example_dag_folder = os.path.join( os.path.dirname(__file__), 'example_dags') self.collect_dags(example_dag_folder) self.collect_dags(dag_folder) if sync_to_db: self.deactivate_inactive_dags()
""" :return: the amount of dags contained in this dagbag """ return len(self.dags)
""" Gets the DAG out of the dictionary, and refreshes it if expired """ # If asking for a known subdag, we want to refresh the parent root_dag_id = dag_id if dag_id in self.dags: dag = self.dags[dag_id] if dag.is_subdag: root_dag_id = dag.parent_dag.dag_id
# If the root_dag_id is absent or expired orm_dag = DagModel.get_current(root_dag_id) if orm_dag and ( root_dag_id not in self.dags or ( orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired ) ): # Reprocessing source file found_dags = self.process_file( filepath=orm_dag.fileloc, only_if_updated=False)
if found_dags and dag_id in [dag.dag_id for dag in found_dags]: return self.dags[dag_id] elif dag_id in self.dags: del self.dags[dag_id] return self.dags.get(dag_id)
""" Given a path to a python module or zip file, this method imports the module and look for dag objects within it. """ found_dags = []
# todo: raise exception? if not os.path.isfile(filepath): return found_dags
try: # This failed before in what may have been a git sync # race condition dttm = datetime.fromtimestamp(os.path.getmtime(filepath)) if only_if_updated \ and filepath in self.file_last_changed \ and dttm == self.file_last_changed[filepath]: return found_dags
except Exception as e: logging.exception(e) return found_dags
mods = [] if not zipfile.is_zipfile(filepath): if safe_mode and os.path.isfile(filepath): with open(filepath, 'rb') as f: content = f.read() if not all([s in content for s in (b'DAG', b'airflow')]): return found_dags
self.logger.debug("Importing {}".format(filepath)) org_mod_name, file_ext = os.path.splitext(os.path.split(filepath)[-1]) mod_name = 'unusual_prefix_' + org_mod_name
if mod_name in sys.modules: del sys.modules[mod_name]
with timeout(configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")): try: m = imp.load_source(mod_name, filepath) mods.append(m) except Exception as e: self.logger.exception("Failed to import: " + filepath) self.import_errors[filepath] = str(e) self.file_last_changed[filepath] = dttm
else: zip_file = zipfile.ZipFile(filepath) for mod in zip_file.infolist(): head, tail = os.path.split(mod.filename) mod_name, ext = os.path.splitext(mod.filename) if not head and (ext == '.py' or ext == '.pyc'): if mod_name == '__init__': self.logger.warning("Found __init__.{0} at root of {1}". format(ext, filepath))
if safe_mode: with zip_file.open(mod.filename) as zf: self.logger.debug("Reading {} from {}". format(mod.filename, filepath)) content = zf.read() if not all([s in content for s in (b'DAG', b'airflow')]): # todo: create ignore list return found_dags
if mod_name in sys.modules: del sys.modules[mod_name]
try: sys.path.insert(0, filepath) m = importlib.import_module(mod_name) mods.append(m) except Exception as e: self.logger.exception("Failed to import: " + filepath) self.import_errors[filepath] = str(e) self.file_last_changed[filepath] = dttm
for m in mods: for dag in list(m.__dict__.values()): if isinstance(dag, DAG): if not dag.full_filepath: dag.full_filepath = filepath dag.is_subdag = False dag.module_name = m.__name__ self.bag_dag(dag, parent_dag=dag, root_dag=dag) found_dags.append(dag) found_dags += dag.subdags
self.file_last_changed[filepath] = dttm return found_dags
def kill_zombies(self, session): """ Fails tasks that haven't had a heartbeat in too long """ from airflow.jobs import LocalTaskJob as LJ self.logger.info("Finding 'running' jobs without a recent heartbeat") TI = TaskInstance secs = ( configuration.getint('scheduler', 'job_heartbeat_sec') * 3) + 120 limit_dttm = datetime.now() - timedelta(seconds=secs) self.logger.info( "Failing jobs without heartbeat after {}".format(limit_dttm))
tis = ( session.query(TI) .join(LJ, TI.job_id == LJ.id) .filter(TI.state == State.RUNNING) .filter( or_( LJ.state != State.RUNNING, LJ.latest_heartbeat < limit_dttm, )) .all() )
for ti in tis: if ti and ti.dag_id in self.dags: dag = self.dags[ti.dag_id] if ti.task_id in dag.task_ids: task = dag.get_task(ti.task_id) ti.task = task ti.handle_failure("{} killed as zombie".format(ti)) self.logger.info( 'Marked zombie job {} as failed'.format(ti)) session.commit()
""" Adds the DAG into the bag, recurses into sub dags. """ self.dags[dag.dag_id] = dag dag.resolve_template_files() dag.last_loaded = datetime.now()
for task in dag.tasks: settings.policy(task)
if self.sync_to_db: session = settings.Session() orm_dag = session.query( DagModel).filter(DagModel.dag_id == dag.dag_id).first() if not orm_dag: orm_dag = DagModel(dag_id=dag.dag_id) orm_dag.fileloc = root_dag.full_filepath orm_dag.is_subdag = dag.is_subdag orm_dag.owners = root_dag.owner orm_dag.is_active = True session.merge(orm_dag) session.commit() session.close()
for subdag in dag.subdags: subdag.full_filepath = dag.full_filepath subdag.parent_dag = dag subdag.fileloc = root_dag.full_filepath subdag.is_subdag = True self.bag_dag(subdag, parent_dag=dag, root_dag=root_dag) self.logger.debug('Loaded DAG {dag}'.format(**locals()))
self, dag_folder=None, only_if_updated=True): """ Given a file path or a folder, this method looks for python modules, imports them and adds them to the dagbag collection.
Note that if a .airflowignore file is found while processing, the directory, it will behaves much like a .gitignore does, ignoring files that match any of the regex patterns specified in the file. """ start_dttm = datetime.now() dag_folder = dag_folder or self.dag_folder
# Used to store stats around DagBag processing stats = [] FileLoadStat = namedtuple( 'FileLoadStat', "file duration dag_num task_num dags") if os.path.isfile(dag_folder): self.process_file(dag_folder, only_if_updated=only_if_updated) elif os.path.isdir(dag_folder): patterns = [] for root, dirs, files in os.walk(dag_folder, followlinks=True): ignore_file = [f for f in files if f == '.airflowignore'] if ignore_file: f = open(os.path.join(root, ignore_file[0]), 'r') patterns += [p for p in f.read().split('\n') if p] f.close() for f in files: try: filepath = os.path.join(root, f) if not os.path.isfile(filepath): continue mod_name, file_ext = os.path.splitext( os.path.split(filepath)[-1]) if file_ext != '.py' and not zipfile.is_zipfile(filepath): continue if not any( [re.findall(p, filepath) for p in patterns]): ts = datetime.now() found_dags = self.process_file( filepath, only_if_updated=only_if_updated)
td = datetime.now() - ts td = td.total_seconds() + ( float(td.microseconds) / 1000000) stats.append(FileLoadStat( filepath.replace(dag_folder, ''), td, len(found_dags), sum([len(dag.tasks) for dag in found_dags]), str([dag.dag_id for dag in found_dags]), )) except Exception as e: logging.warning(e) Stats.gauge( 'collect_dags', (datetime.now() - start_dttm).total_seconds(), 1) Stats.gauge( 'dagbag_size', len(self.dags), 1) Stats.gauge( 'dagbag_import_errors', len(self.import_errors), 1) self.dagbag_stats = sorted( stats, key=lambda x: x.duration, reverse=True)
"""Prints a report around DagBag loading stats""" report = textwrap.dedent("""\n ------------------------------------------------------------------- DagBag loading stats for {dag_folder} ------------------------------------------------------------------- Number of DAGs: {dag_num} Total task number: {task_num} DagBag parsing time: {duration} {table} """) stats = self.dagbag_stats return report.format( dag_folder=self.dag_folder, duration=sum([o.duration for o in stats]), dag_num=sum([o.dag_num for o in stats]), task_num=sum([o.dag_num for o in stats]), table=pprinttable(stats), )
active_dag_ids = [dag.dag_id for dag in list(self.dags.values())] session = settings.Session() for dag in session.query( DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all(): dag.is_active = False session.merge(dag) session.commit() session.close()
session = settings.Session() dag_ids = [dp.dag_id for dp in session.query(DagModel).filter( DagModel.is_paused == True)] session.commit() session.close() return dag_ids
return self.username
return str(self.id)
return self.superuser
""" Placeholder to store information about different database instances connection information. The idea here is that scripts use references to database instances (conn_id) instead of hard coding hostname, logins and passwords when using operators or hooks. """
self, conn_id=None, conn_type=None, host=None, login=None, password=None, schema=None, port=None, extra=None, uri=None): self.conn_id = conn_id if uri: self.parse_from_uri(uri) else: self.conn_type = conn_type self.host = host self.login = login self.password = password self.schema = schema self.port = port self.extra = extra
temp_uri = urlparse(uri) hostname = temp_uri.hostname or '' if '%2f' in hostname: hostname = hostname.replace('%2f', '/').replace('%2F', '/') conn_type = temp_uri.scheme if conn_type == 'postgresql': conn_type = 'postgres' self.conn_type = conn_type self.host = hostname self.schema = temp_uri.path[1:] self.login = temp_uri.username self.password = temp_uri.password self.port = temp_uri.port
if self._password and self.is_encrypted: if not ENCRYPTION_ON: raise AirflowException( "Can't decrypt, configuration is missing") return FERNET.decrypt(bytes(self._password, 'utf-8')).decode() else: return self._password
if value: try: self._password = FERNET.encrypt(bytes(value, 'utf-8')).decode() self.is_encrypted = True except NameError: self._password = value self.is_encrypted = False
def password(cls): descriptor=property(cls.get_password, cls.set_password))
if self._extra and self.is_extra_encrypted: if not ENCRYPTION_ON: raise AirflowException( "Can't decrypt `extra`, configuration is missing") return FERNET.decrypt(bytes(self._extra, 'utf-8')).decode() else: return self._extra
if value: try: self._extra = FERNET.encrypt(bytes(value, 'utf-8')).decode() self.is_extra_encrypted = True except NameError: self._extra = value self.is_extra_encrypted = False
def extra(cls): descriptor=property(cls.get_extra, cls.set_extra))
try: if self.conn_type == 'mysql': from airflow.hooks.mysql_hook import MySqlHook return MySqlHook(mysql_conn_id=self.conn_id) elif self.conn_type == 'google_cloud_platform': from airflow.contrib.hooks.bigquery_hook import BigQueryHook return BigQueryHook(bigquery_conn_id=self.conn_id) elif self.conn_type == 'postgres': from airflow.hooks.postgres_hook import PostgresHook return PostgresHook(postgres_conn_id=self.conn_id) elif self.conn_type == 'hive_cli': from airflow.hooks.hive_hooks import HiveCliHook return HiveCliHook(hive_cli_conn_id=self.conn_id) elif self.conn_type == 'presto': from airflow.hooks.presto_hook import PrestoHook return PrestoHook(presto_conn_id=self.conn_id) elif self.conn_type == 'hiveserver2': from airflow.hooks.hive_hooks import HiveServer2Hook return HiveServer2Hook(hiveserver2_conn_id=self.conn_id) elif self.conn_type == 'sqlite': from airflow.hooks.sqlite_hook import SqliteHook return SqliteHook(sqlite_conn_id=self.conn_id) elif self.conn_type == 'jdbc': from airflow.hooks.jdbc_hook import JdbcHook return JdbcHook(jdbc_conn_id=self.conn_id) elif self.conn_type == 'mssql': from airflow.hooks.mssql_hook import MsSqlHook return MsSqlHook(mssql_conn_id=self.conn_id) elif self.conn_type == 'oracle': from airflow.hooks.oracle_hook import OracleHook return OracleHook(oracle_conn_id=self.conn_id) elif self.conn_type == 'vertica': from airflow.contrib.hooks.vertica_hook import VerticaHook return VerticaHook(vertica_conn_id=self.conn_id) elif self.conn_type == 'cloudant': from airflow.contrib.hooks.cloudant_hook import CloudantHook return CloudantHook(cloudant_conn_id=self.conn_id) except: return None
return self.conn_id
def extra_dejson(self): """Returns the extra property by deserializing json""" obj = {} if self.extra: try: obj = json.loads(self.extra) except Exception as e: logging.exception(e) logging.error( "Failed parsing the json for " "conn_id {}".format(self.conn_id)) return obj
""" Dags can originate from different places (user repos, master repo, ...) and also get executed in different places (different executors). This object represents a version of a DAG and becomes a source of truth for a BackfillJob execution. A pickle is a native python serialized object, and in this case gets stored in the database for the duration of the job.
The executors pick up the DagPickle id and read the dag definition from the database. """
self.dag_id = dag.dag_id if hasattr(dag, 'template_env'): dag.template_env = None self.pickle_hash = hash(dag) self.pickle = dag
""" Task instances store the state of a task instance. This table is the authority and single source of truth around what tasks have run and the state they are in.
The SqlAchemy model doesn't have a SqlAlchemy foreign key to the task or dag model deliberately to have more control over transactions.
Database transactions on this table should insure double triggers and any confusion around what task instances are or aren't ready to run even while multiple schedulers may be firing task instances. """
Index('ti_dag_state', dag_id, state), Index('ti_state_lkp', dag_id, task_id, execution_date, state), Index('ti_pool', pool, state, priority_weight), )
self.dag_id = task.dag_id self.task_id = task.task_id self.execution_date = execution_date self.task = task self.queue = task.queue self.pool = task.pool self.priority_weight = task.priority_weight_total self.try_number = 0 self.test_mode = False # can be changed when calling 'run' self.force = False # can be changed when calling 'run' self.unixname = getpass.getuser() if state: self.state = state
self, mark_success=False, ignore_dependencies=False, ignore_depends_on_past=False, force=False, local=False, pickle_id=None, raw=False, job_id=None, pool=None): """ Returns a command that can be executed anywhere where airflow is installed. This command is part of the message sent to executors by the orchestrator. """ dag = self.task.dag iso = self.execution_date.isoformat() cmd = "airflow run {self.dag_id} {self.task_id} {iso} " cmd += "--mark_success " if mark_success else "" cmd += "--pickle {pickle_id} " if pickle_id else "" cmd += "--job_id {job_id} " if job_id else "" cmd += "-i " if ignore_dependencies else "" cmd += "-I " if ignore_depends_on_past else "" cmd += "--force " if force else "" cmd += "--local " if local else "" cmd += "--pool {pool} " if pool else "" cmd += "--raw " if raw else "" if not pickle_id and dag: if dag.full_filepath != dag.filepath: cmd += "-sd DAGS_FOLDER/{dag.filepath} " elif dag.full_filepath: cmd += "-sd {dag.full_filepath}" return cmd.format(**locals())
def log_filepath(self): iso = self.execution_date.isoformat() log = os.path.expanduser(configuration.get('core', 'BASE_LOG_FOLDER')) return ( "{log}/{self.dag_id}/{self.task_id}/{iso}.log".format(**locals()))
def log_url(self): iso = self.execution_date.isoformat() BASE_URL = configuration.get('webserver', 'BASE_URL') return BASE_URL + ( "/admin/airflow/log" "?dag_id={self.dag_id}" "&task_id={self.task_id}" "&execution_date={iso}" ).format(**locals())
def mark_success_url(self): iso = self.execution_date.isoformat() BASE_URL = configuration.get('webserver', 'BASE_URL') return BASE_URL + ( "/admin/airflow/action" "?action=success" "&task_id={self.task_id}" "&dag_id={self.dag_id}" "&execution_date={iso}" "&upstream=false" "&downstream=false" ).format(**locals())
""" Get the very latest state from the database, if a session is passed, we use and looking up the state becomes part of the session, otherwise a new session is used. """ TI = TaskInstance ti = session.query(TI).filter( TI.dag_id == self.dag_id, TI.task_id == self.task_id, TI.execution_date == self.execution_date, ).all() if ti: state = ti[0].state else: state = None return state
""" Forces the task instance's state to FAILED in the database. """ logging.error("Recording the task instance as FAILED") self.state = State.FAILED session.merge(self) session.commit()
""" Refreshes the task instance from the database based on the primary key
:param lock_for_update: if True, indicates that the database should lock the TaskInstance (issuing a FOR UPDATE clause) until the session is committed. """ TI = TaskInstance
qry = session.query(TI).filter( TI.dag_id == self.dag_id, TI.task_id == self.task_id, TI.execution_date == self.execution_date)
if lock_for_update: ti = qry.with_for_update().first() else: ti = qry.first() if ti: self.state = ti.state self.start_date = ti.start_date self.end_date = ti.end_date self.try_number = ti.try_number else: self.state = None
""" Clears all XCom data from the database for the task instance """ session.query(XCom).filter( XCom.dag_id == self.dag_id, XCom.task_id == self.task_id, XCom.execution_date == self.execution_date ).delete() session.commit()
def key(self): """ Returns a tuple that identifies the task instance uniquely """ return (self.dag_id, self.task_id, self.execution_date)
self.state = state self.start_date = datetime.now() self.end_date = datetime.now() session.merge(self) session.commit()
self, include_queued=False, ignore_depends_on_past=False, flag_upstream_failed=False): """ Returns a boolean on whether the task instance has met all dependencies and is ready to run. It considers the task's state, the state of its dependencies, depends_on_past and makes sure the execution isn't in the future. It doesn't take into account whether the pool has a slot for it to run.
:param include_queued: If True, tasks that have already been queued are included. Defaults to False. :type include_queued: boolean :param ignore_depends_on_past: if True, ignores depends_on_past dependencies. Defaults to False. :type ignore_depends_on_past: boolean :param flag_upstream_failed: This is a hack to generate the upstream_failed state creation while checking to see whether the task instance is runnable. It was the shortest path to add the feature :type flag_upstream_failed: boolean """ # is the execution date in the future? if self.execution_date > datetime.now(): return False # is the task still in the retry waiting period? elif self.is_premature(): return False # does the task have an end_date prior to the execution date? elif self.task.end_date and self.execution_date > self.task.end_date: return False # has the task been skipped? elif self.state == State.SKIPPED: return False # has the task already been queued (and are we excluding queued tasks)? elif self.state == State.QUEUED and not include_queued: return False # is the task runnable and have its dependencies been met? elif ( self.state in State.runnable() and self.are_dependencies_met( ignore_depends_on_past=ignore_depends_on_past, flag_upstream_failed=flag_upstream_failed)): return True # anything else else: return False
""" Returns whether a task is in UP_FOR_RETRY state and its retry interval has elapsed. """ # is the task still in the retry waiting period? return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
self, include_queued=False, ignore_depends_on_past=False, flag_upstream_failed=False): """ Returns whether a task is ready to run AND there's room in the queue.
:param include_queued: If True, tasks that are already QUEUED are considered "runnable". Defaults to False. :type include_queued: boolean :param ignore_depends_on_past: if True, ignores depends_on_past dependencies. Defaults to False. :type ignore_depends_on_past: boolean """ queueable = self.is_queueable( include_queued=include_queued, ignore_depends_on_past=ignore_depends_on_past, flag_upstream_failed=flag_upstream_failed) return queueable and not self.pool_full()
""" Checks whether the dependents of this task instance have all succeeded. This is meant to be used by wait_for_downstream.
This is useful when you do not want to start processing the next schedule of a task until the dependents are done. For instance, if the task DROPs and recreates a table. """ task = self.task
if not task.downstream_task_ids: return True
ti = session.query(func.count(TaskInstance.task_id)).filter( TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(task.downstream_task_ids), TaskInstance.execution_date == self.execution_date, TaskInstance.state == State.SUCCESS, ) count = ti[0][0] return count == len(task.downstream_task_ids)
upstream_failed, done, flag_upstream_failed, session=None): """ Returns a boolean on whether the current task can be scheduled for execution based on its trigger_rule.
:param flag_upstream_failed: This is a hack to generate the upstream_failed state creation while checking to see whether the task instance is runnable. It was the shortest path to add the feature :type flag_upstream_failed: boolean :param successes: Number of successful upstream tasks :type successes: boolean :param skipped: Number of skipped upstream tasks :type skipped: boolean :param failed: Number of failed upstream tasks :type failed: boolean :param upstream_failed: Number of upstream_failed upstream tasks :type upstream_failed: boolean :param done: Number of completed upstream tasks :type done: boolean """ TR = TriggerRule
task = self.task upstream = len(task.upstream_task_ids) tr = task.trigger_rule upstream_done = done >= upstream
# handling instant state assignment based on trigger rules if flag_upstream_failed: if tr == TR.ALL_SUCCESS: if upstream_failed or failed: self.set_state(State.UPSTREAM_FAILED, session) elif skipped: self.set_state(State.SKIPPED, session) elif tr == TR.ALL_FAILED: if successes or skipped: self.set_state(State.SKIPPED, session) elif tr == TR.ONE_SUCCESS: if upstream_done and not successes: self.set_state(State.SKIPPED, session) elif tr == TR.ONE_FAILED: if upstream_done and not (failed or upstream_failed): self.set_state(State.SKIPPED, session)
return ( (tr == TR.ONE_SUCCESS and successes > 0) or (tr == TR.ONE_FAILED and (failed or upstream_failed)) or (tr == TR.ALL_SUCCESS and successes >= upstream) or (tr == TR.ALL_FAILED and failed + upstream_failed >= upstream) or (tr == TR.ALL_DONE and upstream_done) )
self, session=None, flag_upstream_failed=False, ignore_depends_on_past=False, verbose=False): """ Returns a boolean on whether the upstream tasks are in a SUCCESS state and considers depends_on_past and the previous run's state.
:param flag_upstream_failed: This is a hack to generate the upstream_failed state creation while checking to see whether the task instance is runnable. It was the shortest path to add the feature :type flag_upstream_failed: boolean :param ignore_depends_on_past: if True, ignores depends_on_past dependencies. Defaults to False. :type ignore_depends_on_past: boolean :param verbose: verbose provides more logging in the case where the task instance is evaluated as a check right before being executed. In the case of the scheduler evaluating the dependencies, this logging would be way too verbose. :type verbose: boolean """ TI = TaskInstance TR = TriggerRule
task = self.task
# Checking that the depends_on_past is fulfilled if (task.depends_on_past and not ignore_depends_on_past and not self.execution_date == task.start_date): previous_ti = session.query(TI).filter( TI.dag_id == self.dag_id, TI.task_id == task.task_id, TI.execution_date == self.task.dag.previous_schedule(self.execution_date), TI.state.in_({State.SUCCESS, State.SKIPPED}), ).first() if not previous_ti: if verbose: logging.warning("depends_on_past not satisfied") return False
# Applying wait_for_downstream previous_ti.task = self.task if task.wait_for_downstream and not \ previous_ti.are_dependents_done(session=session): if verbose: logging.warning("wait_for_downstream not satisfied") return False
# Checking that all upstream dependencies have succeeded if not task.upstream_list or task.trigger_rule == TR.DUMMY: return True
# todo: this query becomes quite expensive with dags that have # many tasks. It should be refactored to let the task report # to the dag run and get the aggregates from there qry = ( session .query( func.coalesce(func.sum( case([(TI.state == State.SUCCESS, 1)], else_=0)), 0), func.coalesce(func.sum( case([(TI.state == State.SKIPPED, 1)], else_=0)), 0), func.coalesce(func.sum( case([(TI.state == State.FAILED, 1)], else_=0)), 0), func.coalesce(func.sum( case([(TI.state == State.UPSTREAM_FAILED, 1)], else_=0)), 0), func.count(TI.task_id), ) .filter( TI.dag_id == self.dag_id, TI.task_id.in_(task.upstream_task_ids), TI.execution_date == self.execution_date, TI.state.in_([ State.SUCCESS, State.FAILED, State.UPSTREAM_FAILED, State.SKIPPED]), ) )
successes, skipped, failed, upstream_failed, done = qry.first()
satisfied = self.evaluate_trigger_rule( session=session, successes=successes, skipped=skipped, failed=failed, upstream_failed=upstream_failed, done=done, flag_upstream_failed=flag_upstream_failed)
if verbose and not satisfied: logging.warning("Trigger rule `{}` not satisfied".format(task.trigger_rule)) return satisfied
return ( "<TaskInstance: {ti.dag_id}.{ti.task_id} " "{ti.execution_date} [{ti.state}]>" ).format(ti=self)
""" Checks on whether the task instance is in the right state and timeframe to be retried. """ return self.state == State.UP_FOR_RETRY and \ self.end_date + self.task.retry_delay < datetime.now()
def pool_full(self, session): """ Returns a boolean as to whether the slot pool has room for this task to run """ if not self.task.pool: return False
pool = ( session .query(Pool) .filter(Pool.pool == self.task.pool) .first() ) if not pool: raise ValueError( "Task specified a pool ({}) but the pool " "doesn't exist!".format(self.task.pool)) open_slots = pool.open_slots(session=session)
return open_slots <= 0
self, verbose=True, ignore_dependencies=False, ignore_depends_on_past=False, force=False, mark_success=False, test_mode=False, job_id=None, pool=None, session=None): """ Runs the task instance.
:param verbose: whether to turn on more verbose loggin :type verbose: boolean :param ignore_dependencies: Doesn't check for deps, just runs :type ignore_dependencies: boolean :param ignore_depends_on_past: Ignore depends_on_past but respect other dependencies :type ignore_depends_on_past: boolean :param force: Forces a run regarless of previous success :type force: boolean :param mark_success: Don't run the task, mark its state as success :type mark_success: boolean :param test_mode: Doesn't record success or failure in the DB :type test_mode: boolean :param pool: specifies the pool to use to run the task instance :type pool: str """ task = self.task self.pool = pool or task.pool self.test_mode = test_mode self.force = force self.refresh_from_db(session=session, lock_for_update=True) self.job_id = job_id iso = datetime.now().isoformat() self.hostname = socket.getfqdn() self.operator = task.__class__.__name__
if self.state == State.RUNNING: logging.warning("Another instance is running, skipping.") elif self.state == State.REMOVED: self.clear_xcom_data() logging.debug("Task {} was removed from the dag".format(self)) elif not force and self.state == State.SUCCESS: logging.info( "Task {self} previously succeeded" " on {self.end_date}".format(**locals()) ) Stats.incr('previously_succeeded', 1, 1) elif ( not ignore_dependencies and not self.are_dependencies_met( session=session, ignore_depends_on_past=ignore_depends_on_past, verbose=True)): logging.warning("Dependencies not met yet") elif ( # todo: move this to the scheduler self.state == State.UP_FOR_RETRY and not self.ready_for_retry()): next_run = (self.end_date + task.retry_delay).isoformat() logging.info( "Not ready for retry yet. " + "Next run after {0}".format(next_run) ) elif force or self.state in State.runnable(): self.clear_xcom_data() HR = "\n" + ("-" * 80) + "\n" # Line break
# For reporting purposes, we report based on 1-indexed, # not 0-indexed lists (i.e. Attempt 1 instead of # Attempt 0 for the first attempt). msg = "Starting attempt {attempt} of {total}".format( attempt=self.try_number % (task.retries + 1) + 1, total=task.retries + 1) self.start_date = datetime.now()
if not mark_success and self.state != State.QUEUED and ( self.pool or self.task.dag.concurrency_reached): # If a pool is set for this task, marking the task instance # as QUEUED self.state = State.QUEUED msg = "Queuing attempt {attempt} of {total}".format( attempt=self.try_number % (task.retries + 1) + 1, total=task.retries + 1) logging.info(HR + msg + HR)
self.queued_dttm = datetime.now() session.merge(self) session.commit() logging.info("Queuing into pool {}".format(self.pool)) return
# print status message logging.info(HR + msg + HR) self.try_number += 1
if not test_mode: session.add(Log(State.RUNNING, self)) self.state = State.RUNNING self.end_date = None if not test_mode: session.merge(self) session.commit()
# Closing all pooled connections to prevent # "max number of connections reached" settings.engine.dispose() if verbose: if mark_success: msg = "Marking success for " else: msg = "Executing " msg += "{self.task} on {self.execution_date}"
context = {} try: logging.info(msg.format(self=self)) if not mark_success: context = self.get_template_context()
task_copy = copy.copy(task) self.task = task_copy
def signal_handler(signum, frame): '''Setting kill signal handler''' logging.error("Killing subprocess") task_copy.on_kill() raise AirflowException("Task received SIGTERM signal") signal.signal(signal.SIGTERM, signal_handler)
self.render_templates() task_copy.pre_execute(context=context)
# If a timout is specified for the task, make it fail # if it goes beyond result = None if task_copy.execution_timeout: with timeout(int( task_copy.execution_timeout.total_seconds())): result = task_copy.execute(context=context)
else: result = task_copy.execute(context=context)
# If the task returns a result, push an XCom containing it if result is not None: self.xcom_push(key=XCOM_RETURN_KEY, value=result)
task_copy.post_execute(context=context) self.state = State.SUCCESS except AirflowSkipException: self.state = State.SKIPPED except (Exception, KeyboardInterrupt) as e: self.handle_failure(e, test_mode, context) raise
# Recording SUCCESS self.end_date = datetime.now() self.set_duration() if not test_mode: session.add(Log(self.state, self)) session.merge(self) session.commit()
# Success callback try: if task.on_success_callback: task.on_success_callback(context) except Exception as e3: logging.error("Failed when executing success callback") logging.exception(e3)
session.commit()
task = self.task task_copy = copy.copy(task) self.task = task_copy
self.render_templates() task_copy.dry_run()
logging.exception(error) task = self.task session = settings.Session() self.end_date = datetime.now() self.set_duration() if not test_mode: session.add(Log(State.FAILED, self))
# Let's go deeper try: if task.retries and self.try_number % (task.retries + 1) != 0: self.state = State.UP_FOR_RETRY logging.info('Marking task as UP_FOR_RETRY') if task.email_on_retry and task.email: self.email_alert(error, is_retry=True) else: self.state = State.FAILED if task.retries: logging.info('All retries failed; marking task as FAILED') else: logging.info('Marking task as FAILED.') if task.email_on_failure and task.email: self.email_alert(error, is_retry=False) except Exception as e2: logging.error( 'Failed to send email to: ' + str(task.email)) logging.exception(e2)
# Handling callbacks pessimistically try: if self.state == State.UP_FOR_RETRY and task.on_retry_callback: task.on_retry_callback(context) if self.state == State.FAILED and task.on_failure_callback: task.on_failure_callback(context) except Exception as e3: logging.error("Failed at executing callback") logging.exception(e3)
if not test_mode: session.merge(self) session.commit() logging.error(str(error))
task = self.task from airflow import macros tables = None if 'tables' in task.params: tables = task.params['tables']
ds = self.execution_date.isoformat()[:10] ts = self.execution_date.isoformat() yesterday_ds = (self.execution_date - timedelta(1)).isoformat()[:10] tomorrow_ds = (self.execution_date + timedelta(1)).isoformat()[:10]
ds_nodash = ds.replace('-', '') ts_nodash = ts.replace('-', '').replace(':', '') yesterday_ds_nodash = yesterday_ds.replace('-', '') tomorrow_ds_nodash = tomorrow_ds.replace('-', '')
ti_key_str = "{task.dag_id}__{task.task_id}__{ds_nodash}" ti_key_str = ti_key_str.format(**locals())
params = {} run_id = '' dag_run = None if hasattr(task, 'dag'): if task.dag.params: params.update(task.dag.params) dag_run = ( session.query(DagRun) .filter_by( dag_id=task.dag.dag_id, execution_date=self.execution_date) .first() ) run_id = dag_run.run_id if dag_run else None session.expunge_all() session.commit()
if task.params: params.update(task.params)
class VariableAccessor: """ Wrapper around Variable. This way you can get variables in templates by using {var.variable_name}. """ def __init__(self): pass
def __getattr__(self, item): return Variable.get(item)
class VariableJsonAccessor: def __init__(self): pass
def __getattr__(self, item): return Variable.get(item, deserialize_json=True)
return { 'dag': task.dag, 'ds': ds, 'ds_nodash': ds_nodash, 'ts': ts, 'ts_nodash': ts_nodash, 'yesterday_ds': yesterday_ds, 'yesterday_ds_nodash': yesterday_ds_nodash, 'tomorrow_ds': tomorrow_ds, 'tomorrow_ds_nodash': tomorrow_ds_nodash, 'END_DATE': ds, 'end_date': ds, 'dag_run': dag_run, 'run_id': run_id, 'execution_date': self.execution_date, 'latest_date': ds, 'macros': macros, 'params': params, 'tables': tables, 'task': task, 'task_instance': self, 'ti': self, 'task_instance_key_str': ti_key_str, 'conf': configuration, 'test_mode': self.test_mode, 'var': { 'value': VariableAccessor(), 'json': VariableJsonAccessor() } }
task = self.task jinja_context = self.get_template_context() if hasattr(self, 'task') and hasattr(self.task, 'dag'): if self.task.dag.user_defined_macros: jinja_context.update( self.task.dag.user_defined_macros)
rt = self.task.render_template # shortcut to method for attr in task.__class__.template_fields: content = getattr(task, attr) if content: rendered_content = rt(attr, content, jinja_context) setattr(task, attr, rendered_content)
task = self.task title = "Airflow alert: {self}".format(**locals()) exception = str(exception).replace('\n', '<br>') try_ = task.retries + 1 body = ( "Try {self.try_number} out of {try_}<br>" "Exception:<br>{exception}<br>" "Log: <a href='{self.log_url}'>Link</a><br>" "Host: {self.hostname}<br>" "Log file: {self.log_filepath}<br>" "Mark success: <a href='{self.mark_success_url}'>Link</a><br>" ).format(**locals()) send_email(task.email, title, body)
if self.end_date and self.start_date: self.duration = (self.end_date - self.start_date).total_seconds() else: self.duration = None
self, key, value, execution_date=None): """ Make an XCom available for tasks to pull.
:param key: A key for the XCom :type key: string :param value: A value for the XCom. The value is pickled and stored in the database. :type value: any pickleable object :param execution_date: if provided, the XCom will not be visible until this date. This can be used, for example, to send a message to a task on a future date without it being immediately visible. :type execution_date: datetime """
if execution_date and execution_date < self.execution_date: raise ValueError( 'execution_date can not be in the past (current ' 'execution_date is {}; received {})'.format( self.execution_date, execution_date))
XCom.set( key=key, value=value, task_id=self.task_id, dag_id=self.dag_id, execution_date=execution_date or self.execution_date)
self, task_ids, dag_id=None, key=XCOM_RETURN_KEY, include_prior_dates=False): """ Pull XComs that optionally meet certain criteria.
The default value for `key` limits the search to XComs that were returned by other tasks (as opposed to those that were pushed manually). To remove this filter, pass key=None (or any desired value).
If a single task_id string is provided, the result is the value of the most recent matching XCom from that task_id. If multiple task_ids are provided, a tuple of matching values is returned. None is returned whenever no matches are found.
:param key: A key for the XCom. If provided, only XComs with matching keys will be returned. The default key is 'return_value', also available as a constant XCOM_RETURN_KEY. This key is automatically given to XComs returned by tasks (as opposed to being pushed manually). To remove the filter, pass key=None. :type key: string :param task_ids: Only XComs from tasks with matching ids will be pulled. Can pass None to remove the filter. :type task_ids: string or iterable of strings (representing task_ids) :param dag_id: If provided, only pulls XComs from this DAG. If None (default), the DAG of the calling task is used. :type dag_id: string :param include_prior_dates: If False, only XComs from the current execution_date are returned. If True, XComs from previous dates are returned as well. :type include_prior_dates: bool """
if dag_id is None: dag_id = self.dag_id
pull_fn = functools.partial( XCom.get_one, execution_date=self.execution_date, key=key, dag_id=dag_id, include_prior_dates=include_prior_dates)
if is_container(task_ids): return tuple(pull_fn(task_id=t) for t in task_ids) else: return pull_fn(task_id=task_ids)
""" Used to actively log events to the database """
self.dttm = datetime.now() self.event = event self.extra = extra
task_owner = None
if task_instance: self.dag_id = task_instance.dag_id self.task_id = task_instance.task_id self.execution_date = task_instance.execution_date task_owner = task_instance.task.owner
if 'task_id' in kwargs: self.task_id = kwargs['task_id'] if 'dag_id' in kwargs: self.dag_id = kwargs['dag_id'] if 'execution_date' in kwargs: if kwargs['execution_date']: self.execution_date = kwargs['execution_date']
self.owner = owner or task_owner
""" Abstract base class for all operators. Since operators create objects that become node in the dag, BaseOperator contains many recursive methods for dag crawling behavior. To derive this class, you are expected to override the constructor as well as the 'execute' method.
Operators derived from this task should perform or trigger certain tasks synchronously (wait for completion). Example of operators could be an operator the runs a Pig job (PigOperator), a sensor operator that waits for a partition to land in Hive (HiveSensorOperator), or one that moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these operators (tasks) target specific operations, running specific scripts, functions or data transfers.
This class is abstract and shouldn't be instantiated. Instantiating a class derived from this one results in the creation of a task object, which ultimately becomes a node in DAG objects. Task dependencies should be set by using the set_upstream and/or set_downstream methods.
Note that this class is derived from SQLAlchemy's Base class, which allows us to push metadata regarding tasks to the database. Deriving this classes needs to implement the polymorphic specificities documented in SQLAlchemy. This should become clear while reading the code for other operators.
:param task_id: a unique, meaningful id for the task :type task_id: string :param owner: the owner of the task, using the unix username is recommended :type owner: string :param retries: the number of retries that should be performed before failing the task :type retries: int :param retry_delay: delay between retries :type retry_delay: timedelta :param start_date: The ``start_date`` for the task, determines the ``execution_date`` for the first task instance. The best practice is to have the start_date rounded to your DAG's ``schedule_interval``. Daily jobs have their start_date some day at 00:00:00, hourly jobs have their start_date at 00:00 of a specific hour. Note that Airflow simply looks at the latest ``execution_date`` and adds the ``schedule_interval`` to determine the next ``execution_date``. It is also very important to note that different tasks' dependencies need to line up in time. If task A depends on task B and their start_date are offset in a way that their execution_date don't line up, A's dependencies will never be met. If you are looking to delay a task, for example running a daily task at 2AM, look into the ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using dynamic ``start_date`` and recommend using fixed ones. Read the FAQ entry about start_date for more information. :type start_date: datetime :param end_date: if specified, the scheduler won't go beyond this date :type end_date: datetime :param depends_on_past: when set to true, task instances will run sequentially while relying on the previous task's schedule to succeed. The task instance for the start_date is allowed to run. :type depends_on_past: bool :param wait_for_downstream: when set to true, an instance of task X will wait for tasks immediately downstream of the previous instance of task X to finish successfully before it runs. This is useful if the different instances of a task X alter the same asset, and this asset is used by tasks downstream of task X. Note that depends_on_past is forced to True wherever wait_for_downstream is used. :type wait_for_downstream: bool :param queue: which queue to target when running this job. Not all executors implement queue management, the CeleryExecutor does support targeting specific queues. :type queue: str :param dag: a reference to the dag the task is attached to (if any) :type dag: DAG :param priority_weight: priority weight of this task against other task. This allows the executor to trigger higher priority tasks before others when things get backed up. :type priority_weight: int :param pool: the slot pool this task should run in, slot pools are a way to limit concurrency for certain tasks :type pool: str :param sla: time by which the job is expected to succeed. Note that this represents the ``timedelta`` after the period is closed. For example if you set an SLA of 1 hour, the scheduler would send dan email soon after 1:00AM on the ``2016-01-02`` if the ``2016-01-01`` instance has not succeede yet. The scheduler pays special attention for jobs with an SLA and sends alert emails for sla misses. SLA misses are also recorded in the database for future reference. All tasks that share the same SLA time get bundled in a single email, sent soon after that time. SLA notification are sent once and only once for each task instance. :type sla: datetime.timedelta :param execution_timeout: max time allowed for the execution of this task instance, if it goes beyond it will raise and fail. :type execution_timeout: datetime.timedelta :param on_failure_callback: a function to be called when a task instance of this task fails. a context dictionary is passed as a single parameter to this function. Context contains references to related objects to the task instance and is documented under the macros section of the API. :type on_failure_callback: callable :param on_retry_callback: much like the ``on_failure_callback`` excepts that it is executed when retries occur. :param on_success_callback: much like the ``on_failure_callback`` excepts that it is executed when the task succeeds. :type on_success_callback: callable :param trigger_rule: defines the rule by which dependencies are applied for the task to get triggered. Options are: ``{ all_success | all_failed | all_done | one_success | one_failed | dummy}`` default is ``all_success``. Options can be set as string or using the constants defined in the static class ``airflow.utils.TriggerRule`` :type trigger_rule: str """
# For derived classes to define which fields will get jinjaified # Defines wich files extensions to look for in the templated fields # Defines the color in the UI
self, task_id, owner=configuration.get('operators', 'DEFAULT_OWNER'), email=None, email_on_retry=True, email_on_failure=True, retries=0, retry_delay=timedelta(seconds=300), start_date=None, end_date=None, schedule_interval=None, # not hooked as of now depends_on_past=False, wait_for_downstream=False, dag=None, params=None, default_args=None, adhoc=False, priority_weight=1, queue=configuration.get('celery', 'default_queue'), pool=None, sla=None, execution_timeout=None, on_failure_callback=None, on_success_callback=None, on_retry_callback=None, trigger_rule=TriggerRule.ALL_SUCCESS, *args, **kwargs):
if args or kwargs: # TODO remove *args and **kwargs in Airflow 2.0 warnings.warn( 'Invalid arguments were passed to {c}. Support for ' 'passing such arguments will be dropped in Airflow 2.0. ' 'Invalid arguments were:' '\n*args: {a}\n**kwargs: {k}'.format( c=self.__class__.__name__, a=args, k=kwargs), category=PendingDeprecationWarning )
validate_key(task_id) self.task_id = task_id self.owner = owner self.email = email self.email_on_retry = email_on_retry self.email_on_failure = email_on_failure self.start_date = start_date if start_date and not isinstance(start_date, datetime): logging.warning( "start_date for {} isn't datetime.datetime".format(self)) self.end_date = end_date if not TriggerRule.is_valid(trigger_rule): raise AirflowException( "The trigger_rule must be one of {all_triggers}," "'{d}.{t}'; received '{tr}'." .format(all_triggers=TriggerRule.all_triggers, d=dag.dag_id, t=task_id, tr = trigger_rule))
self.trigger_rule = trigger_rule self.depends_on_past = depends_on_past self.wait_for_downstream = wait_for_downstream if wait_for_downstream: self.depends_on_past = True
if schedule_interval: logging.warning( "schedule_interval is used for {}, though it has " "been deprecated as a task parameter, you need to " "specify it as a DAG parameter instead".format(self)) self._schedule_interval = schedule_interval self.retries = retries self.queue = queue self.pool = pool self.sla = sla self.execution_timeout = execution_timeout self.on_failure_callback = on_failure_callback self.on_success_callback = on_success_callback self.on_retry_callback = on_retry_callback if isinstance(retry_delay, timedelta): self.retry_delay = retry_delay else: logging.debug("retry_delay isn't timedelta object, assuming secs") self.retry_delay = timedelta(seconds=retry_delay) self.params = params or {} # Available in templates! self.adhoc = adhoc self.priority_weight = priority_weight
# Private attributes self._upstream_task_ids = [] self._downstream_task_ids = []
if not dag and _CONTEXT_MANAGER_DAG: dag = _CONTEXT_MANAGER_DAG if dag: self.dag = dag
self._comps = { 'task_id', 'dag_id', 'owner', 'email', 'email_on_retry', 'retry_delay', 'start_date', 'schedule_interval', 'depends_on_past', 'wait_for_downstream', 'adhoc', 'priority_weight', 'sla', 'execution_timeout', 'on_failure_callback', 'on_success_callback', 'on_retry_callback', }
return ( type(self) == type(other) and all(self.__dict__.get(c, None) == other.__dict__.get(c, None) for c in self._comps))
return not self == other
return self.task_id < other.task_id
hash_components = [type(self)] for c in self._comps: val = getattr(self, c, None) try: hash(val) hash_components.append(val) except TypeError: hash_components.append(repr(val)) return hash(tuple(hash_components))
# Composing Operators -----------------------------------------------
""" Implements Self >> Other == self.set_downstream(other)
If "Other" is a DAG, the DAG is assigned to the Operator. """ if isinstance(other, DAG): # if this dag is already assigned, do nothing # otherwise, do normal dag assignment if not (self.has_dag() and self.dag is other): self.dag = other else: self.set_downstream(other) return other
""" Implements Self << Other == self.set_upstream(other)
If "Other" is a DAG, the DAG is assigned to the Operator. """ if isinstance(other, DAG): # if this dag is already assigned, do nothing # otherwise, do normal dag assignment if not (self.has_dag() and self.dag is other): self.dag = other else: self.set_upstream(other) return other
""" Called for [DAG] >> [Operator] because DAGs don't have __rshift__ operators. """ self.__lshift__(other) return self
""" Called for [DAG] << [Operator] because DAGs don't have __lshift__ operators. """ self.__rshift__(other) return self
# /Composing Operators ---------------------------------------------
def dag(self): """ Returns the Operator's DAG if set, otherwise raises an error """ if self.has_dag(): return self._dag else: raise AirflowException( 'Operator {} has not been assigned to a DAG yet'.format(self))
def dag(self, dag): """ Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok. """ if not isinstance(dag, DAG): raise TypeError( 'Expected DAG; received {}'.format(dag.__class__.__name__)) elif self.has_dag() and self.dag is not dag: raise AirflowException( "The DAG assigned to {} can not be changed.".format(self)) elif self.task_id not in dag.task_dict: dag.add_task(self)
self._dag = dag
""" Returns True if the Operator has been assigned to a DAG. """ return getattr(self, '_dag', None) is not None
def dag_id(self): if self.has_dag(): return self.dag.dag_id else: return 'adhoc_' + self.owner
def schedule_interval(self): """ The schedule interval of the DAG always wins over individual tasks so that tasks within a DAG always line up. The task still needs a schedule_interval as it may not be attached to a DAG. """ if self.has_dag(): return self.dag._schedule_interval else: return self._schedule_interval
def priority_weight_total(self): return sum([ t.priority_weight for t in self.get_flat_relatives(upstream=False) ]) + self.priority_weight
""" This is triggered right before self.execute, it's mostly a hook for people deriving operators. """ pass
""" This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.
Refer to get_template_context for more context. """ raise NotImplementedError()
""" This is triggered right after self.execute, it's mostly a hook for people deriving operators. """ pass
''' Override this method to cleanup subprocesses when a task instance gets killed. Any use of the threading, subprocess or multiprocessing module within an operator needs to be cleaned up or it will leave ghost processes behind. ''' pass
""" Hack sorting double chained task lists by task_id to avoid hitting max_depth on deepcopy operations. """ sys.setrecursionlimit(5000) # TODO fix this in a better way cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result
for k, v in list(self.__dict__.items()): if k not in ('user_defined_macros', 'params'): setattr(result, k, copy.deepcopy(v, memo)) result.params = self.params if hasattr(self, 'user_defined_macros'): result.user_defined_macros = self.user_defined_macros return result
''' Renders a template from a field. If the field is a string, it will simply render the string and return the result. If it is a collection or nested set of collections, it will traverse the structure and render all strings in it. ''' rt = self.render_template if isinstance(content, six.string_types): result = jinja_env.from_string(content).render(**context) elif isinstance(content, (list, tuple)): result = [rt(attr, e, context) for e in content] elif isinstance(content, dict): result = { k: rt("{}[{}]".format(attr, k), v, context) for k, v in list(content.items())} else: param_type = type(content) msg = ( "Type '{param_type}' used for parameter '{attr}' is " "not supported for templating").format(**locals()) raise AirflowException(msg) return result
''' Renders a template either from a file or directly in a field, and returns the rendered result. ''' jinja_env = self.dag.get_template_env() \ if hasattr(self, 'dag') \ else jinja2.Environment(cache_size=0)
exts = self.__class__.template_ext if ( isinstance(content, six.string_types) and any([content.endswith(ext) for ext in exts])): return jinja_env.get_template(content).render(**context) else: return self.render_template_from_field(attr, content, context, jinja_env)
''' Hook that is triggered after the templated fields get replaced by their content. If you need your operator to alter the content of the file before the template is rendered, it should override this method to do so. ''' pass
# Getting the content of files for template_field / template_ext for attr in self.template_fields: content = getattr(self, attr) if (content and isinstance(content, six.string_types) and any([content.endswith(ext) for ext in self.template_ext])): env = self.dag.get_template_env() try: setattr(self, attr, env.loader.get_source(env, content)[0]) except Exception as e: logging.exception(e) self.prepare_template()
def upstream_list(self): """@property: list of tasks directly upstream""" return [self.dag.get_task(tid) for tid in self._upstream_task_ids]
def upstream_task_ids(self): return self._upstream_task_ids
def downstream_list(self): """@property: list of tasks directly downstream""" return [self.dag.get_task(tid) for tid in self._downstream_task_ids]
def downstream_task_ids(self): return self._downstream_task_ids
self, start_date=None, end_date=None, upstream=False, downstream=False): """ Clears the state of task instances associated with the task, following the parameters specified. """ session = settings.Session()
TI = TaskInstance qry = session.query(TI).filter(TI.dag_id == self.dag_id)
if start_date: qry = qry.filter(TI.execution_date >= start_date) if end_date: qry = qry.filter(TI.execution_date <= end_date)
tasks = [self.task_id]
if upstream: tasks += [ t.task_id for t in self.get_flat_relatives(upstream=True)]
if downstream: tasks += [ t.task_id for t in self.get_flat_relatives(upstream=False)]
qry = qry.filter(TI.task_id.in_(tasks))
count = qry.count() clear_task_instances(qry, session)
session.commit() session.close() return count
""" Get a set of task instance related to this task for a specific date range. """ TI = TaskInstance end_date = end_date or datetime.now() return session.query(TI).filter( TI.dag_id == self.dag_id, TI.task_id == self.task_id, TI.execution_date >= start_date, TI.execution_date <= end_date, ).order_by(TI.execution_date).all()
""" Get a flat list of relatives, either upstream or downstream. """ if not l: l = [] for t in self.get_direct_relatives(upstream): if not is_in(t, l): l.append(t) t.get_flat_relatives(upstream, l) return l
""" When invoked, this routine will raise an exception if a cycle is detected downstream from self. It is invoked when tasks are added to the DAG to detect cycles. """ if not task: task = self for t in self.get_direct_relatives(): if task is t: msg = "Cycle detected in DAG. Faulty task: {0}".format(task) raise AirflowException(msg) else: t.detect_downstream_cycle(task=task) return False
self, start_date=None, end_date=None, ignore_dependencies=False, ignore_first_depends_on_past=False, force=False, mark_success=False): """ Run a set of task instances for a date range. """ start_date = start_date or self.start_date end_date = end_date or self.end_date or datetime.now()
for dt in self.dag.date_range(start_date, end_date=end_date): TaskInstance(self, dt).run( mark_success=mark_success, ignore_dependencies=ignore_dependencies, ignore_depends_on_past=( dt == start_date and ignore_first_depends_on_past), force=force,)
logging.info('Dry run') for attr in self.template_fields: content = getattr(self, attr) if content and isinstance(content, six.string_types): logging.info('Rendering template for {0}'.format(attr)) logging.info(content)
""" Get the direct relatives to the current task, upstream or downstream. """ if upstream: return self.upstream_list else: return self.downstream_list
return "<Task({self.__class__.__name__}): {self.task_id}>".format( self=self)
def task_type(self): return self.__class__.__name__
if any([item is t for t in l]): raise AirflowException( 'Dependency {self}, {item} already registered' ''.format(**locals())) else: l.append(item)
try: task_list = list(task_or_task_list) except TypeError: task_list = [task_or_task_list]
for t in task_list: if not isinstance(t, BaseOperator): raise AirflowException( "Relationships can only be set between " "Operators; received {}".format(t.__class__.__name__))
# relationships can only be set if the tasks share a single DAG. Tasks # without a DAG are assigned to that DAG. dags = set(t.dag for t in [self] + task_list if t.has_dag())
if len(dags) > 1: raise AirflowException( 'Tried to set relationships between tasks in ' 'more than one DAG: {}'.format(dags)) elif len(dags) == 1: dag = list(dags)[0] else: raise AirflowException( "Tried to create relationships between tasks that don't have " "DAGs yet. Set the DAG for at least one " "task and try again: {}".format([self] + task_list))
if dag and not self.has_dag(): self.dag = dag
for task in task_list: if dag and not task.has_dag(): task.dag = dag if upstream: task.append_only_new(task._downstream_task_ids, self.task_id) self.append_only_new(self._upstream_task_ids, task.task_id) else: self.append_only_new(self._downstream_task_ids, task.task_id) task.append_only_new(task._upstream_task_ids, self.task_id)
self.detect_downstream_cycle()
""" Set a task, or a task task to be directly downstream from the current task. """ self._set_relatives(task_or_task_list, upstream=False)
""" Set a task, or a task task to be directly upstream from the current task. """ self._set_relatives(task_or_task_list, upstream=True)
self, context, key, value, execution_date=None): """ See TaskInstance.xcom_push() """ context['ti'].xcom_push( key=key, value=value, execution_date=execution_date)
self, context, task_ids, dag_id=None, key=XCOM_RETURN_KEY, include_prior_dates=None): """ See TaskInstance.xcom_pull() """ return context['ti'].xcom_pull( key=key, task_ids=task_ids, dag_id=dag_id, include_prior_dates=include_prior_dates)
""" These items are stored in the database for state related information """ # A DAG can be paused from the UI / DB # Set this default value of is_paused based on a configuration value! # Whether the DAG is a subdag # Whether that DAG was seen on the last DagBag load # Last time the scheduler started # Last time this DAG was pickled # When the DAG received a refreshed signal last, used to know when # we need to force refresh # Whether (one of) the scheduler is scheduling this DAG at the moment # Foreign key to the latest pickle_id # The location of the file containing the DAG object # String representing the owners
return "<DAG: {self.dag_id}>".format(self=self)
def get_current(cls, dag_id): session = settings.Session() obj = session.query(cls).filter(cls.dag_id == dag_id).first() session.expunge_all() session.commit() session.close() return obj
""" A dag (directed acyclic graph) is a collection of tasks with directional dependencies. A dag also has a schedule, a start end an end date (optional). For each schedule, (say daily or hourly), the DAG needs to run each individual tasks as their dependencies are met. Certain tasks have the property of depending on their own past, meaning that they can't run until their previous schedule (and upstream tasks) are completed.
DAGs essentially act as namespaces for tasks. A task_id can only be added once to a DAG.
:param dag_id: The id of the DAG :type dag_id: string :param schedule_interval: Defines how often that DAG runs, this timedelta object gets added to your latest task instance's execution_date to figure out the next schedule :type schedule_interval: datetime.timedelta or dateutil.relativedelta.relativedelta or str that acts as a cron expression :param start_date: The timestamp from which the scheduler will attempt to backfill :type start_date: datetime.datetime :param end_date: A date beyond which your DAG won't run, leave to None for open ended scheduling :type end_date: datetime.datetime :param template_searchpath: This list of folders (non relative) defines where jinja will look for your templates. Order matters. Note that jinja/airflow includes the path of your DAG file by default :type template_searchpath: string or list of stings :param user_defined_macros: a dictionary of macros that will be exposed in your jinja templates. For example, passing ``dict(foo='bar')`` to this argument allows you to ``{{ foo }}`` in all jinja templates related to this DAG. Note that you can pass any type of object here. :type user_defined_macros: dict :param default_args: A dictionary of default parameters to be used as constructor keyword parameters when initialising operators. Note that operators have the same hook, and precede those defined here, meaning that if your dict contains `'depends_on_past': True` here and `'depends_on_past': False` in the operator's call `default_args`, the actual value will be `False`. :type default_args: dict :param params: a dictionary of DAG level parameters that are made accessible in templates, namespaced under `params`. These params can be overridden at the task level. :type params: dict :param concurrency: the number of task instances allowed to run concurrently :type concurrency: int :param max_active_runs: maximum number of active DAG runs, beyond this number of DAG runs in a running state, the scheduler won't create new active DAG runs :type max_active_runs: int :param dagrun_timeout: specify how long a DagRun should be up before timing out / failing, so that new DagRuns can be created :type dagrun_timeout: datetime.timedelta :param sla_miss_callback: specify a function to call when reporting SLA timeouts. :type sla_miss_callback: types.FunctionType """
self, dag_id, schedule_interval=timedelta(days=1), start_date=None, end_date=None, full_filepath=None, template_searchpath=None, user_defined_macros=None, default_args=None, concurrency=configuration.getint('core', 'dag_concurrency'), max_active_runs=configuration.getint( 'core', 'max_active_runs_per_dag'), dagrun_timeout=None, sla_miss_callback=None, params=None):
self.user_defined_macros = user_defined_macros self.default_args = default_args or {} self.params = params or {}
# merging potentially conflicting default_args['params'] into params if 'params' in self.default_args: self.params.update(self.default_args['params']) del self.default_args['params']
validate_key(dag_id) self.task_dict = dict() self.dag_id = dag_id self.start_date = start_date self.end_date = end_date self.schedule_interval = schedule_interval if schedule_interval in cron_presets: self._schedule_interval = cron_presets.get(schedule_interval) elif schedule_interval == '@once': self._schedule_interval = None else: self._schedule_interval = schedule_interval self.full_filepath = full_filepath if full_filepath else '' if isinstance(template_searchpath, six.string_types): template_searchpath = [template_searchpath] self.template_searchpath = template_searchpath self.parent_dag = None # Gets set when DAGs are loaded self.last_loaded = datetime.now() self.safe_dag_id = dag_id.replace('.', '__dot__') self.concurrency = concurrency self.max_active_runs = max_active_runs self.dagrun_timeout = dagrun_timeout self.sla_miss_callback = sla_miss_callback
self._comps = { 'dag_id', 'task_ids', 'parent_dag', 'start_date', 'schedule_interval', 'full_filepath', 'template_searchpath', 'last_loaded', }
return "<DAG: {self.dag_id}>".format(self=self)
return ( type(self) == type(other) and all(self.__dict__.get(c, None) == other.__dict__.get(c, None) for c in self._comps))
return not self == other
return self.dag_id < other.dag_id
hash_components = [type(self)] for c in self._comps: # task_ids returns a list and lists can't be hashed if c == 'task_ids': val = tuple(self.task_dict.keys()) else: val = getattr(self, c, None) try: hash(val) hash_components.append(val) except TypeError: hash_components.append(repr(val)) return hash(tuple(hash_components))
# Context Manager -----------------------------------------------
global _CONTEXT_MANAGER_DAG self._old_context_manager_dag = _CONTEXT_MANAGER_DAG _CONTEXT_MANAGER_DAG = self return self
global _CONTEXT_MANAGER_DAG _CONTEXT_MANAGER_DAG = self._old_context_manager_dag
# /Context Manager ----------------------------------------------
if num: end_date = None return utils_date_range( start_date=start_date, end_date=end_date, num=num, delta=self._schedule_interval)
if isinstance(self._schedule_interval, six.string_types): cron = croniter(self._schedule_interval, dttm) return cron.get_next(datetime) elif isinstance(self._schedule_interval, timedelta): return dttm + self._schedule_interval
if isinstance(self._schedule_interval, six.string_types): cron = croniter(self._schedule_interval, dttm) return cron.get_prev(datetime) elif isinstance(self._schedule_interval, timedelta): return dttm - self._schedule_interval
""" Returns dttm + interval unless dttm is first interval then it returns dttm """ following = self.following_schedule(dttm)
# in case of @once if not following: return dttm
if self.previous_schedule(following) != dttm: return following
return dttm
def tasks(self): return list(self.task_dict.values())
def tasks(self, val): raise AttributeError( 'DAG.tasks can not be modified. Use dag.add_task() instead.')
def task_ids(self): return list(self.task_dict.keys())
def active_task_ids(self): return list(k for k, v in self.task_dict.items() if not v.adhoc)
def active_tasks(self): return [t for t in self.tasks if not t.adhoc]
def filepath(self): """ File location of where the dag object is instantiated """ fn = self.full_filepath.replace(DAGS_FOLDER + '/', '') fn = fn.replace(os.path.dirname(__file__) + '/', '') return fn
def folder(self): """ Folder location of where the dag object is instantiated """ return os.path.dirname(self.full_filepath)
def owner(self): return ", ".join(list(set([t.owner for t in self.tasks])))
""" Returns a boolean indicating whether the concurrency limit for this DAG has been reached """ TI = TaskInstance qry = session.query(func.count(TI.task_id)).filter( TI.dag_id == self.dag_id, TI.task_id.in_(self.task_ids), TI.state == State.RUNNING, ) return qry.scalar() >= self.concurrency
""" Returns a boolean indicating whether this DAG is paused """ qry = session.query(DagModel).filter( DagModel.dag_id == self.dag_id) return qry.value('is_paused')
def latest_execution_date(self): """ Returns the latest date for which at least one task instance exists """ TI = TaskInstance session = settings.Session() execution_date = session.query(func.max(TI.execution_date)).filter( TI.dag_id == self.dag_id, TI.task_id.in_(self.task_ids) ).scalar() session.commit() session.close() return execution_date
def subdags(self): """ Returns a list of the subdag objects associated to this DAG """ # Check SubDag for class but don't check class directly, see # https://github.com/airbnb/airflow/issues/1168 l = [] for task in self.tasks: if ( task.__class__.__name__ == 'SubDagOperator' and hasattr(task, 'subdag')): l.append(task.subdag) l += task.subdag.subdags return l
for t in self.tasks: t.resolve_template_files()
""" Typically called at the end of a script by passing globals() as a parameter. This allows to not explicitly add every single task to the dag explicitly. """ raise NotImplementedError("")
''' Returns a jinja2 Environment while taking into account the DAGs template_searchpath and user_defined_macros ''' searchpath = [self.folder] if self.template_searchpath: searchpath += self.template_searchpath
env = jinja2.Environment( loader=jinja2.FileSystemLoader(searchpath), extensions=["jinja2.ext.do"], cache_size=0) if self.user_defined_macros: env.globals.update(self.user_defined_macros)
return env
""" Simple utility method to set dependency between two tasks that already have been added to the DAG using add_task() """ self.get_task(upstream_task_id).set_downstream( self.get_task(downstream_task_id))
self, session, start_date=None, end_date=None, state=None): TI = TaskInstance if not start_date: start_date = (datetime.today()-timedelta(30)).date() start_date = datetime.combine(start_date, datetime.min.time()) end_date = end_date or datetime.now() tis = session.query(TI).filter( TI.dag_id == self.dag_id, TI.execution_date >= start_date, TI.execution_date <= end_date, TI.task_id.in_([t.task_id for t in self.tasks]), ) if state: tis = tis.filter(TI.state == state) tis = tis.all() return tis
def roots(self): return [t for t in self.tasks if not t.downstream_list]
self, start_date, end_date, state=State.RUNNING, session=None): dates = utils_date_range(start_date, end_date) drs = session.query(DagModel).filter_by(dag_id=self.dag_id).all() for dr in drs: dr.state = State.RUNNING
self, start_date=None, end_date=None, only_failed=False, only_running=False, confirm_prompt=False, include_subdags=True, reset_dag_runs=True, dry_run=False): session = settings.Session() """ Clears a set of task instances associated with the current dag for a specified date range. """ TI = TaskInstance tis = session.query(TI) if include_subdags: # Crafting the right filter for dag_id and task_ids combo conditions = [] for dag in self.subdags + [self]: conditions.append( TI.dag_id.like(dag.dag_id) & TI.task_id.in_(dag.task_ids) ) tis = tis.filter(or_(*conditions)) else: tis = session.query(TI).filter(TI.dag_id == self.dag_id) tis = tis.filter(TI.task_id.in_(self.task_ids))
if start_date: tis = tis.filter(TI.execution_date >= start_date) if end_date: tis = tis.filter(TI.execution_date <= end_date) if only_failed: tis = tis.filter(TI.state == State.FAILED) if only_running: tis = tis.filter(TI.state == State.RUNNING)
if dry_run: tis = tis.all() session.expunge_all() return tis
count = tis.count() do_it = True if count == 0: print("Nothing to clear.") return 0 if confirm_prompt: ti_list = "\n".join([str(t) for t in tis]) question = ( "You are about to delete these {count} tasks:\n" "{ti_list}\n\n" "Are you sure? (yes/no): ").format(**locals()) do_it = utils.helpers.ask_yesno(question)
if do_it: clear_task_instances(tis, session) if reset_dag_runs: self.set_dag_runs_state(start_date, end_date, session=session) else: count = 0 print("Bail. Nothing was cleared.")
session.commit() session.close() return count
# Swiwtcharoo to go around deepcopying objects coming through the # backdoor cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for k, v in list(self.__dict__.items()): if k not in ('user_defined_macros', 'params'): setattr(result, k, copy.deepcopy(v, memo))
result.user_defined_macros = self.user_defined_macros result.params = self.params return result
include_upstream=True): """ Returns a subset of the current dag as a deep copy of the current dag based on a regex that should match one or many tasks, and includes upstream and downstream neighbours based on the flag passed. """
dag = copy.deepcopy(self)
regex_match = [ t for t in dag.tasks if re.findall(task_regex, t.task_id)] also_include = [] for t in regex_match: if include_downstream: also_include += t.get_flat_relatives(upstream=False) if include_upstream: also_include += t.get_flat_relatives(upstream=True)
# Compiling the unique list of tasks that made the cut dag.task_dict = {t.task_id: t for t in regex_match + also_include} for t in dag.tasks: # Removing upstream/downstream references to tasks that did not # made the cut t._upstream_task_ids = [ tid for tid in t._upstream_task_ids if tid in dag.task_ids] t._downstream_task_ids = [ tid for tid in t._downstream_task_ids if tid in dag.task_ids] return dag
return task_id in (t.task_id for t in self.tasks)
if task_id in self.task_dict: return self.task_dict[task_id] raise AirflowException("Task {task_id} not found".format(**locals()))
d = {} d['is_picklable'] = True try: dttm = datetime.now() pickled = pickle.dumps(self) d['pickle_len'] = len(pickled) d['pickling_duration'] = "{}".format(datetime.now() - dttm) except Exception as e: logging.exception(e) d['is_picklable'] = False d['stacktrace'] = traceback.format_exc() return d
dag = session.query( DagModel).filter(DagModel.dag_id == self.dag_id).first() dp = None if dag and dag.pickle_id: dp = session.query(DagPickle).filter( DagPickle.id == dag.pickle_id).first() if not dp or dp.pickle != self: dp = DagPickle(dag=self) session.add(dp) self.last_pickled = datetime.now() session.commit() self.pickle_id = dp.id
return dp
""" Shows an ascii tree representation of the DAG """ def get_downstream(task, level=0): print((" " * level * 4) + str(task)) level += 1 for t in task.upstream_list: get_downstream(t, level)
for t in self.roots: get_downstream(t)
''' Add a task to the DAG
:param task: the task you want to add :type task: task ''' if not self.start_date and not task.start_date: raise AirflowException("Task is missing the start_date parameter") if not task.start_date: task.start_date = self.start_date
if task.task_id in self.task_dict: #TODO raise an error in Airflow 2.0 warnings.warn( 'The requested task could not be added to the DAG because a ' 'task with task_id {} is already in the DAG. Starting in ' 'Airflow 2.0, trying to overwrite a task will raise an ' 'exception.'.format(task.task_id), category=PendingDeprecationWarning) else: self.tasks.append(task) self.task_dict[task.task_id] = task task.dag = self
self.task_count = len(self.tasks)
''' Add a list of tasks to the DAG
:param task: a lit of tasks you want to add :type task: list of tasks ''' for task in tasks: self.add_task(task)
BO = BaseOperator session = settings.Session() tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all() for t in tasks: session.delete(t) session.commit() session.merge(self) session.commit()
self, start_date=None, end_date=None, mark_success=False, include_adhoc=False, local=False, executor=None, donot_pickle=configuration.getboolean('core', 'donot_pickle'), ignore_dependencies=False, ignore_first_depends_on_past=False, pool=None): """ Runs the DAG. """ from airflow.jobs import BackfillJob if not executor and local: executor = LocalExecutor() elif not executor: executor = DEFAULT_EXECUTOR job = BackfillJob( self, start_date=start_date, end_date=end_date, mark_success=mark_success, include_adhoc=include_adhoc, executor=executor, donot_pickle=donot_pickle, ignore_dependencies=ignore_dependencies, ignore_first_depends_on_past=ignore_first_depends_on_past, pool=pool) job.run()
""" Exposes a CLI specific to this DAG """ from airflow.bin import cli parser = cli.CLIFactory.get_parser(dag_parser=True) args = parser.parse_args() args.func(args, self)
run_id, execution_date, state, start_date=None, external_trigger=False, conf=None, session=None): """ Creates a dag run from this dag including the tasks associated with this dag. Returns the dag run. :param run_id: defines the the run id for this dag run :type run_id: string :param execution_date: the execution date of this dag run :type execution_date: datetime :param state: the state of the dag run :type state: State :param start_date: the date this dag run should be evaluated :type state_date: datetime :param external_trigger: whether this dag run is externally triggered :type external_trigger: bool :param session: database session :type session: Session """ run = DagRun( dag_id=self.dag_id, run_id=run_id, execution_date=execution_date, start_date=start_date, external_trigger=external_trigger, conf=conf, state=state ) session.add(run) session.commit()
run.dag = self
# create the associated taskinstances # state is None at the moment of creation run.verify_integrity(session=session)
run.refresh_from_db() return run
"User", cascade=False, cascade_backrefs=False, backref='charts')
return self.label
return self.know_event_type
"User", cascade=False, cascade_backrefs=False, backref='known_events') "KnownEventType", cascade=False, cascade_backrefs=False, backref='known_events')
return self.label
# Hiding the value return '{} : {}'.format(self.key, self._val)
if self._val and self.is_encrypted: if not ENCRYPTION_ON: raise AirflowException( "Can't decrypt _val, configuration is missing") return FERNET.decrypt(bytes(self._val, 'utf-8')).decode() else: return self._val
if value: try: self._val = FERNET.encrypt(bytes(value, 'utf-8')).decode() self.is_encrypted = True except NameError: self._val = value self.is_encrypted = False
def val(cls): descriptor=property(cls.get_val, cls.set_val)) obj = session.query(cls).filter(cls.key == key).first() if obj is None: if default_var is not None: return default_var else: raise ValueError('Variable {} does not exist'.format(key)) else: if deserialize_json: return json.loads(obj.val) else: return obj.val
if serialize_json: stored_value = json.dumps(value) else: stored_value = value
session.query(cls).filter(cls.key == key).delete() session.add(Variable(key=key, val=stored_value)) session.flush()
""" Base class for XCom objects. """
DateTime, default=func.now(), nullable=False)
# source information
return '<XCom "{key}" ({task_id} @ {execution_date})>'.format( key=self.key, task_id=self.task_id, execution_date=self.execution_date)
cls, key, value, execution_date, task_id, dag_id, session=None): """ Store an XCom value. """ session.expunge_all()
# remove any duplicate XComs session.query(cls).filter( cls.key == key, cls.execution_date == execution_date, cls.task_id == task_id, cls.dag_id == dag_id).delete()
# insert new XCom session.add(XCom( key=key, value=value, execution_date=execution_date, task_id=task_id, dag_id=dag_id))
session.commit()
cls, execution_date, key=None, task_id=None, dag_id=None, include_prior_dates=False, session=None): """ Retrieve an XCom value, optionally meeting certain criteria """ filters = [] if key: filters.append(cls.key == key) if task_id: filters.append(cls.task_id == task_id) if dag_id: filters.append(cls.dag_id == dag_id) if include_prior_dates: filters.append(cls.execution_date <= execution_date) else: filters.append(cls.execution_date == execution_date)
query = ( session.query(cls.value) .filter(and_(*filters)) .order_by(cls.execution_date.desc(), cls.timestamp.desc()) .limit(1))
result = query.first() if result: return result.value
cls, execution_date, key=None, task_ids=None, dag_ids=None, include_prior_dates=False, limit=100, session=None): """ Retrieve an XCom value, optionally meeting certain criteria """ filters = [] if key: filters.append(cls.key == key) if task_ids: filters.append(cls.task_id.in_(as_tuple(task_ids))) if dag_ids: filters.append(cls.dag_id.in_(as_tuple(dag_ids))) if include_prior_dates: filters.append(cls.execution_date <= execution_date) else: filters.append(cls.execution_date == execution_date)
query = ( session.query(cls) .filter(and_(*filters)) .order_by(cls.execution_date.desc(), cls.timestamp.desc()) .limit(limit))
return query.all()
if isinstance(xcoms, XCom): xcoms = [xcoms] for xcom in xcoms: if not isinstance(xcom, XCom): raise TypeError('Expected XCom; received {}'.format( xcom.__class__.__name__)) session.delete(xcom) session.commit()
""" DagRun describes an instance of a Dag. It can be created by the scheduler (for regular runs) or by an external trigger """
Index('dr_run_id', dag_id, run_id, unique=True), )
return ( '<DagRun {dag_id} @ {execution_date}: {run_id}, ' 'externally triggered: {external_trigger}>' ).format( dag_id=self.dag_id, execution_date=self.execution_date, run_id=self.run_id, external_trigger=self.external_trigger)
return prefix.format(date.isoformat()[:19])
""" Reloads the current dagrun from the database :param session: database session """ DR = DagRun
dr = session.query(DR).filter( DR.dag_id == self.dag_id, DR.execution_date == self.execution_date, DR.run_id == self.run_id ).one() if dr: self.id = dr.id self.state = dr.state
state=None, external_trigger=None, session=None): """ Returns a set of dag runs for the given search criteria. :param dag_id: the dag_id to find dag runs for :type dag_id: integer, list :param run_id: defines the the run id for this dag run :type run_id: string :param execution_date: the execution date :type execution_date: datetime :param state: the state of the dag run :type state: State :param external_trigger: whether this dag run is externally triggered :type external_trigger: bool :param session: database session :type session: Session """ DR = DagRun
qry = session.query(DR) if dag_id: qry = qry.filter(DR.dag_id == dag_id) if run_id: qry = qry.filter(DR.run_id == run_id) if execution_date: qry = qry.filter(DR.execution_date == execution_date) if state: qry = qry.filter(DR.state == state) if external_trigger: qry = qry.filter(DR.external_trigger == external_trigger)
dr = qry.order_by(DR.execution_date).all()
return dr
""" Returns the task instances for this dag run """ TI = TaskInstance tis = session.query(TI).filter( TI.dag_id == self.dag_id, TI.execution_date == self.execution_date, ) if state: if isinstance(state, six.string_types): tis = tis.filter(TI.state == state) else: # this is required to deal with NULL values if None in state: tis = tis.filter(or_( TI.state.in_(state), TI.state.is_(None)) ) else: tis = tis.filter(TI.state.in_(state))
return tis.all()
""" Returns the task instance specified by task_id for this dag run :param task_id: the task id """
TI = TaskInstance ti = session.query(TI).filter( TI.dag_id == self.dag_id, TI.execution_date == self.execution_date, TI.task_id == task_id ).one()
return ti
""" Returns the Dag associated with this DagRun :param session: database session :return: DAG """ if not self.dag: raise AirflowException("The DAG (.dag) for {} needs to be set" .format(self))
return self.dag
""" Determines the overall state of the DagRun based on the state of its TaskInstances. :returns State: """
dag = self.get_dag() tis = self.get_task_instances(session=session)
logging.info("Updating state for {} considering {} task(s)" .format(self, len(tis)))
for ti in tis: # skip in db? if ti.state == State.REMOVED: tis.remove(ti) else: ti.task = dag.get_task(ti.task_id)
# pre-calculate # db is faster start_dttm = datetime.now() unfinished_tasks = self.get_task_instances( state=State.unfinished(), session=session ) none_depends_on_past = all(t.task.depends_on_past for t in unfinished_tasks)
# small speed up if unfinished_tasks and none_depends_on_past: # todo: this can actually get pretty slow: one task costs between 0.01-015s no_dependencies_met = all(not t.are_dependencies_met(session=session) for t in unfinished_tasks)
duration = (datetime.now() - start_dttm).total_seconds() * 1000 Stats.timing("dagrun.dependency-check.{}.{}". format(self.dag_id, self.execution_date), duration)
# future: remove the check on adhoc tasks (=active_tasks) if len(tis) == len(dag.active_tasks): # if any roots failed, the run failed root_ids = [t.task_id for t in dag.roots] roots = [t for t in tis if t.task_id in root_ids]
if any(r.state in (State.FAILED, State.UPSTREAM_FAILED) for r in roots): logging.info('Marking run {} failed'.format(self)) self.state = State.FAILED
# if all roots succeeded, the run succeeded elif all(r.state in (State.SUCCESS, State.SKIPPED) for r in roots): logging.info('Marking run {} successful'.format(self)) self.state = State.SUCCESS
# if *all tasks* are deadlocked, the run failed elif unfinished_tasks and none_depends_on_past and no_dependencies_met: logging.info( 'Deadlock; marking run {} failed'.format(self)) self.state = State.FAILED
# finally, if the roots aren't done, the dag is still running else: self.state = State.RUNNING
# todo: determine we want to use with_for_update to make sure to lock the run session.merge(self) session.commit()
return self.state
""" Verifies the DagRun by checking for removed tasks or tasks that are not in the database yet. It will set state to removed or add the task if required. """ dag = self.get_dag() tis = self.get_task_instances(session=session)
# check for removed tasks task_ids = [] for ti in tis: task_ids.append(ti.task_id) try: dag.get_task(ti.task_id) except AirflowException: if self.state is not State.RUNNING: ti.state = State.REMOVED
# check for missing tasks for task in dag.tasks: if task.adhoc: continue
if task.task_id not in task_ids: ti = TaskInstance(task, self.execution_date) session.add(ti)
session.commit()
return self.pool
def used_slots(self, session): """ Returns the number of slots used at the moment """ running = ( session .query(TaskInstance) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state == State.RUNNING) .count() ) return running
def queued_slots(self, session): """ Returns the number of slots used at the moment """ return ( session .query(TaskInstance) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state == State.QUEUED) .count() )
def open_slots(self, session): """ Returns the number of slots open at the moment """ used_slots = self.used_slots(session=session) return self.slots - used_slots
""" Model that stores a history of the SLA that have been missed. It is used to keep track of SLA failures over time and to avoid double triggering alert emails. """
return str(( self.dag_id, self.task_id, self.execution_date.isoformat()))
|