Coverage for airflow/operators/sensors.py : 32%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
|
# -*- coding: utf-8 -*- # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
''' Sensor operators are derived from this class an inherit these attributes.
Sensor operators keep executing at a time interval and succeed when a criteria is met and fail if and when they time out.
:param soft_fail: Set to true to mark the task as SKIPPED on failure :type soft_fail: bool :param poke_interval: Time in seconds that the job should wait in between each tries :type poke_interval: int :param timeout: Time, in seconds before the task times out and fails. :type timeout: int '''
self, poke_interval=60, timeout=60*60*24*7, soft_fail=False, *args, **kwargs): super(BaseSensorOperator, self).__init__(*args, **kwargs) self.poke_interval = poke_interval self.soft_fail = soft_fail self.timeout = timeout
''' Function that the sensors defined while deriving this class should override. ''' raise AirflowException('Override me.')
started_at = datetime.now() while not self.poke(context): if (datetime.now() - started_at).total_seconds() > self.timeout: if self.soft_fail: raise AirflowSkipException('Snap. Time is OUT.') else: raise AirflowSensorTimeout('Snap. Time is OUT.') sleep(self.poke_interval) logging.info("Success criteria met. Exiting.")
""" Runs a sql statement until a criteria is met. It will keep trying until sql returns no row, or if the first cell in (0, '0', '').
:param conn_id: The connection to run the sensor against :type conn_id: string :param sql: The sql to run. To pass, it needs to return at least one cell that contains a non-zero / empty string value. """
def __init__(self, conn_id, sql, *args, **kwargs): self.sql = sql self.conn_id = conn_id super(SqlSensor, self).__init__(*args, **kwargs)
hook = BaseHook.get_connection(self.conn_id).get_hook()
logging.info('Poking: ' + self.sql) records = hook.get_records(self.sql) if not records: return False else: if str(records[0][0]) in ('0', '',): return False else: return True print(records[0][0])
""" An alternative to the HivePartitionSensor that talk directly to the MySQL db. This was created as a result of observing sub optimal queries generated by the Metastore thrift service when hitting subpartitioned tables. The Thrift service's queries were written in a way that wouldn't leverage the indexes.
:param schema: the schema :type schema: str :param table: the table :type table: str :param partition_name: the partition name, as defined in the PARTITIONS table of the Metastore. Order of the fields does matter. Examples: ``ds=2016-01-01`` or ``ds=2016-01-01/sub=foo`` for a sub partitioned table :type partition_name: str :param mysql_conn_id: a reference to the MySQL conn_id for the metastore :type mysql_conn_id: str """
self, table, partition_name, schema="default", mysql_conn_id="metastore_mysql", *args, **kwargs):
self.partition_name = partition_name self.table = table self.schema = schema self.first_poke = True self.conn_id = mysql_conn_id super(SqlSensor, self).__init__(*args, **kwargs)
if self.first_poke: self.first_poke = False if '.' in self.table: self.schema, self.table = self.table.split('.') self.sql = """ SELECT 'X' FROM PARTITIONS A0 LEFT OUTER JOIN TBLS B0 ON A0.TBL_ID = B0.TBL_ID LEFT OUTER JOIN DBS C0 ON B0.DB_ID = C0.DB_ID WHERE B0.TBL_NAME = '{self.table}' AND C0.NAME = '{self.schema}' AND A0.PART_NAME = '{self.partition_name}'; """.format(self=self) return super(MetastorePartitionSensor, self).poke(context)
""" Waits for a task to complete in a different DAG
:param external_dag_id: The dag_id that contains the task you want to wait for :type external_dag_id: string :param external_task_id: The task_id that contains the task you want to wait for :type external_task_id: string :param allowed_states: list of allowed states, default is ``['success']`` :type allowed_states: list :param execution_delta: time difference with the previous execution to look at, the default is the same execution_date as the current task. For yesterday, use [positive!] datetime.timedelta(days=1) :type execution_delta: datetime.timedelta """
self, external_dag_id, external_task_id, allowed_states=None, execution_delta=None, *args, **kwargs): super(ExternalTaskSensor, self).__init__(*args, **kwargs) self.allowed_states = allowed_states or [State.SUCCESS] self.execution_delta = execution_delta self.external_dag_id = external_dag_id self.external_task_id = external_task_id
if self.execution_delta: dttm = context['execution_date'] - self.execution_delta else: dttm = context['execution_date']
logging.info( 'Poking for ' '{self.external_dag_id}.' '{self.external_task_id} on ' '{dttm} ... '.format(**locals())) TI = TaskInstance
session = settings.Session() count = session.query(TI).filter( TI.dag_id == self.external_dag_id, TI.task_id == self.external_task_id, TI.state.in_(self.allowed_states), TI.execution_date == dttm, ).count() session.commit() session.close() return count
""" Waits for a partition to show up in Hive
:param table: The name of the table to wait for, supports the dot notation (my_database.my_table) :type table: string :param partition: The partition clause to wait for. This is passed as is to the Metastore Thrift client "get_partitions_by_filter" method, and apparently supports SQL like notation as in `ds='2015-01-01' AND type='value'` and > < sings as in "ds>=2015-01-01" :type partition: string :param metastore_conn_id: reference to the metastore thrift service connection id :type metastore_conn_id: str """
self, table, partition="ds='{{ ds }}'", metastore_conn_id='metastore_default', schema='default', poke_interval=60*3, *args, **kwargs): super(HivePartitionSensor, self).__init__( poke_interval=poke_interval, *args, **kwargs) if not partition: partition = "ds='{{ ds }}'" self.metastore_conn_id = metastore_conn_id self.table = table self.partition = partition self.schema = schema
if '.' in self.table: self.schema, self.table = self.table.split('.') logging.info( 'Poking for table {self.schema}.{self.table}, ' 'partition {self.partition}'.format(**locals())) if not hasattr(self, 'hook'): import airflow.hooks.hive_hooks self.hook = airflow.hooks.hive_hooks.HiveMetastoreHook( metastore_conn_id=self.metastore_conn_id) return self.hook.check_for_partition( self.schema, self.table, self.partition)
""" Waits for a file or folder to land in HDFS """
self, filepath, hdfs_conn_id='hdfs_default', *args, **kwargs): super(HdfsSensor, self).__init__(*args, **kwargs) self.filepath = filepath self.hdfs_conn_id = hdfs_conn_id
import airflow.hooks.hdfs_hook sb = airflow.hooks.hdfs_hook.HDFSHook(self.hdfs_conn_id).get_conn() logging.getLogger("snakebite").setLevel(logging.WARNING) logging.info( 'Poking for file {self.filepath} '.format(**locals())) try: files = [f for f in sb.ls([self.filepath])] except: return False return True
""" Waits for a file or folder to land in HDFS """
self, filepath, webhdfs_conn_id='webhdfs_default', *args, **kwargs): super(WebHdfsSensor, self).__init__(*args, **kwargs) self.filepath = filepath self.webhdfs_conn_id = webhdfs_conn_id
c = airflow.hooks.webhdfs_hook.WebHDFSHook(self.webhdfs_conn_id) logging.info( 'Poking for file {self.filepath} '.format(**locals())) return c.check_for_path(hdfs_path=self.filepath)
""" Waits for a key (a file-like instance on S3) to be present in a S3 bucket. S3 being a key/value it does not support folders. The path is just a key a resource.
:param bucket_key: The key being waited on. Supports full s3:// style url or relative path from root level. :type bucket_key: str :param bucket_name: Name of the S3 bucket :type bucket_name: str :param wildcard_match: whether the bucket_key should be interpreted as a Unix wildcard pattern :type wildcard_match: bool :param s3_conn_id: a reference to the s3 connection :type s3_conn_id: str """
self, bucket_key, bucket_name=None, wildcard_match=False, s3_conn_id='s3_default', *args, **kwargs): super(S3KeySensor, self).__init__(*args, **kwargs) session = settings.Session() db = session.query(DB).filter(DB.conn_id == s3_conn_id).first() if not db: raise AirflowException("conn_id doesn't exist in the repository") # Parse if bucket_name is None: parsed_url = urlparse(bucket_key) if parsed_url.netloc == '': raise AirflowException('Please provide a bucket_name') else: bucket_name = parsed_url.netloc if parsed_url.path[0] == '/': bucket_key = parsed_url.path[1:] else: bucket_key = parsed_url.path self.bucket_name = bucket_name self.bucket_key = bucket_key self.wildcard_match = wildcard_match self.s3_conn_id = s3_conn_id session.commit() session.close()
import airflow.hooks.S3_hook hook = airflow.hooks.S3_hook.S3Hook(s3_conn_id=self.s3_conn_id) full_url = "s3://" + self.bucket_name + "/" + self.bucket_key logging.info('Poking for key : {full_url}'.format(**locals())) if self.wildcard_match: return hook.check_for_wildcard_key(self.bucket_key, self.bucket_name) else: return hook.check_for_key(self.bucket_key, self.bucket_name)
""" Waits for a prefix to exist. A prefix is the first part of a key, thus enabling checking of constructs similar to glob airfl* or SQL LIKE 'airfl%'. There is the possibility to precise a delimiter to indicate the hierarchy or keys, meaning that the match will stop at that delimiter. Current code accepts sane delimiters, i.e. characters that are NOT special characters in the Python regex engine.
:param bucket_name: Name of the S3 bucket :type bucket_name: str :param prefix: The prefix being waited on. Relative path from bucket root level. :type prefix: str :param delimiter: The delimiter intended to show hierarchy. Defaults to '/'. :type delimiter: str """
self, bucket_name, prefix, delimiter='/', s3_conn_id='s3_default', *args, **kwargs): super(S3PrefixSensor, self).__init__(*args, **kwargs) session = settings.Session() db = session.query(DB).filter(DB.conn_id == s3_conn_id).first() if not db: raise AirflowException("conn_id doesn't exist in the repository") # Parse self.bucket_name = bucket_name self.prefix = prefix self.delimiter = delimiter self.full_url = "s3://" + bucket_name + '/' + prefix self.s3_conn_id = s3_conn_id session.commit() session.close()
logging.info('Poking for prefix : {self.prefix}\n' 'in bucket s3://{self.bucket_name}'.format(**locals())) import airflow.hooks.S3_hook hook = airflow.hooks.S3_hook.S3Hook(s3_conn_id=self.s3_conn_id) return hook.check_for_prefix( prefix=self.prefix, delimiter=self.delimiter, bucket_name=self.bucket_name)
""" Waits until the specified time of the day.
:param target_time: time after which the job succeeds :type target_time: datetime.time """
def __init__(self, target_time, *args, **kwargs): super(TimeSensor, self).__init__(*args, **kwargs) self.target_time = target_time
logging.info( 'Checking if the time ({0}) has come'.format(self.target_time)) return datetime.now().time() > self.target_time
""" Waits for a timedelta after the task's execution_date + schedule_interval. In Airflow, the daily task stamped with ``execution_date`` 2016-01-01 can only start running on 2016-01-02. The timedelta here represents the time after the execution period has closed.
:param delta: time length to wait after execution_date before succeeding :type delta: datetime.timedelta """
def __init__(self, delta, *args, **kwargs): super(TimeDeltaSensor, self).__init__(*args, **kwargs) self.delta = delta
dag = context['dag'] target_dttm = dag.following_schedule(context['execution_date']) target_dttm += self.delta logging.info('Checking if the time ({0}) has come'.format(target_dttm)) return datetime.now() > target_dttm
""" Executes a HTTP get statement and returns False on failure: 404 not found or response_check function returned False
:param http_conn_id: The connection to run the sensor against :type http_conn_id: string :param endpoint: The relative part of the full url :type endpoint: string :param params: The parameters to be added to the GET url :type params: a dictionary of string key/value pairs :param headers: The HTTP headers to be added to the GET request :type headers: a dictionary of string key/value pairs :param response_check: A check against the 'requests' response object. Returns True for 'pass' and False otherwise. :type response_check: A lambda or defined function. :param extra_options: Extra options for the 'requests' library, see the 'requests' documentation (options to modify timeout, ssl, etc.) :type extra_options: A dictionary of options, where key is string and value depends on the option that's being modified. """
endpoint, http_conn_id='http_default', params=None, headers=None, response_check=None, extra_options=None, *args, **kwargs): super(HttpSensor, self).__init__(*args, **kwargs) self.endpoint = endpoint self.http_conn_id = http_conn_id self.params = params or {} self.headers = headers or {} self.extra_options = extra_options or {} self.response_check = response_check
self.hook = hooks.http_hook.HttpHook(method='GET', http_conn_id=http_conn_id)
logging.info('Poking: ' + self.endpoint) try: response = self.hook.run(self.endpoint, data=self.params, headers=self.headers, extra_options=self.extra_options) if self.response_check: # run content check on response return self.response_check(response) except AirflowException as ae: if str(ae).startswith("404"): return False
raise ae
return True |