Airflow自定义插件, 使用datax抽数

  • 2019 年 10 月 3 日
  • 笔记

Airflow自定义插件

Airflow之所以受欢迎的一个重要因素就是它的插件机制。Python成熟类库可以很方便的引入各种插件。在我们实际工作中,必然会遇到官方的一些插件不足够满足需求的时候。这时候,我们可以编写自己的插件。不需要你了解内部原理,甚至不需要很熟悉Python, 反正我连蒙带猜写的。

插件分类

Airflow的插件分为Operator和Sensor两种。Operator是具体要执行的任务插件, Sensor则是条件传感器,当我需要设定某些依赖的时候可以通过不同的sensor来感知条件是否满足。

Airflow对插件提供的支持

插件肯定是Python文件了,系统必然需要加载才能执行。Airflow提供了一个简单插件管理器,会扫描$AIRFLOW_HOME/plugins加载我们的插件。

所以,我们只需要将写好的插件放入这个目录下就可以了。

插件语法

Operator和Sensor都声明了需要的参数,Operator通过调用execute来执行, sensor通过poke来确认。以Operator为例子。

插件的使用过程为:

dag  -> operator -> hook

Hook就是任务执行的具体操作了。

Operator通过继承BaseOperator实现对dag相关属性的绑定, Hook通过继承BaseHook实现对系统配置和资源获取的一些封装。

自定义一个通知插件NotifyOperator

前文https://www.cnblogs.com/woshimrf/p/airflow-dag.html 提到我们通过自定义通知实现多功能任务告警,以下就是一个demo。

文件结构如下:

plugins  │   ├── hooks  │   └── operators

NotifyOperator

首先,在operators目录下创建一个Operator.

# -*- coding: utf-8 -*-  #    from hooks.notify_hook import NotifyHook  from airflow.operators.bash_operator import BaseOperator    class NotifyOperator(BaseOperator):      """      使用通知服务发送通知        :param message: 内容      :type message: str or dict      :param receivers: 英文逗号分割的罗盘账号      :type receivers: str      :param subject: 邮件主题      :type subject: str      """      template_fields = ('message', 'subject')        @apply_defaults      def __init__(self,                   subject=None,                   message=None,                   receivers=None,                   *args,                   **kwargs):          super().__init__(*args, **kwargs)          self.message = message          self.receivers = receivers          self.subject = subject        def execute(self, context):          self.log.info('Sending notify message. receivers:{}  message:{}'.format(self.receivers, self.message))          hook = NotifyHook(              subject=self.subject,              message=self.message,              receivers=self.receivers          )          hook.send()
  • 继承BaseOperator
  • 引入NotifyHook, 这个还没创建,等下创建
  • template_fields, 想要使用模板变量替换,比如{{ds}}, 字段必须声明到template_fields
  • Operator执行的时候会调用execute方法, 这个就是执行的内容

上面可以看出,operator就是接口声明。

NotifyHook

在hooks目录下创建NotifyHook

# -*- coding: utf-8 -*-  #    import json  import requests  from airflow import AirflowException  from airflow.hooks.http_hook import HttpHook    class NotifyHook(HttpHook):      """      使用通知服务发送通知        :param send_type: 通知类型选填 MAIL,DINGDING,SMS,选填多个时中间用英文逗号隔开      :type send_type: str      :param message: 内容      :type message: str or dict      :param receivers: 英文逗号分割的账号      :type receivers: str      :param subject: 邮件主题      :type subject: str      """        def __init__(self,                   notify_conn_id='notify_default',                   send_type='MAIL',                   subject=None,                   message=None,                   receivers=None,                   *args,                   **kwargs                   ):          super().__init__(http_conn_id=notify_conn_id, *args, **kwargs)          self.send_type = send_type          self.message = message          self.subject = subject          self.receivers = receivers          def _build_message(self):          """          构建data          """          data = {                  "content": self.message,                  "contentType": "HTML",                  "receivers": self.receivers,                  "sendType": self.send_type,                  "sender": '【Airflow】',                  "subject": '【Airflow】' + self.subject              }          return json.dumps(data)        def get_conn(self, headers=None):          """          Overwrite HttpHook get_conn because just need base_url and headers and          not don't need generic params            :param headers: additional headers to be passed through as a dictionary          :type headers: dict          """          self.base_url = 'http://notify.ryan-miao.com'          session = requests.Session()          if headers:              session.headers.update(headers)          return session        def send(self):          """          Send Notify message          """          data = self._build_message()          self.log.info('Sending message: %s',  data)          resp = self.run(endpoint='/api/v2/notify/send',                          data=data,                          headers={'Content-Type': 'application/json',                                   'app-id': 'ryan',                                   'app-key': '123456'})          if int(resp.json().get('retCode')) != 0:              raise AirflowException('Send notify message failed, receive error '                                     'message %s', resp.text)          self.log.info('Success Send notify message')
  • 这里使用的我自己的通知服务api调用。因为是http请求,所以直接继承HttpHook来发送请求就可以了。
  • http_conn_id是用来读取数据库中connection里配置的host的,这里直接覆盖,固定我们通知服务的地址。
  • 通过抛出异常的方式来终止服务

如何使用

将上面两个文件放到airflow对应的plugins目录下, airflow就自动加载了。然后,当做任务类型使用

from operators.notify_operator import NotifyOperator    notification = NotifyOperator(                      task_id="we_are_done",                      subject='发送邮件',                      message='content',                      receivers='ryanmiao'                  )  

也可以直接执行。比如,我们前面提到任务失败告警可以自定义通知。

from operators.notify_operator import NotifyOperator    def mail_failure_callback(receivers):      """      失败后邮件通知      :receivers  接收人,多个接收人用英文逗号分开      """        def mail_back(context):          subject="【执行失败】DAG {} TASK {} ds {}".format(                                                  context['task_instance'].dag_id,                                                  context['task_instance'].task_id,                                                  context['ds'])          message="【执行失败】DAG:  {};<br> TASK:  {} <br>; ds {} <br>;   原因: {} .<br>"                   "查看地址: http://airflow.ryan-miao.com/admin/airflow/tree?dag_id={}"                   .format(                  context['task_instance'].dag_id,                  context['task_instance'].task_id,                  context['ds'],                  context['exception'],                  context['task_instance'].dag_id)            return NotifyOperator(                      task_id="mail_failed_notify_callback",                      subject=subject,                      message=message,                      receivers=receivers                  ).execute(context)        return mail_back    default_args = {      'owner': 'ryanmiao',      'depends_on_past': False,      'start_date': datetime(2019, 5, 1, 9),      'on_failure_callback': mail_failure_callback(receivers='ryanmiao'),      'retries': 0  }    dag = DAG(      'example', default_args=default_args, schedule_interval=None)

自定义一个RDBMS2Hive插件

我们任务调度有个常见的服务是数据抽取到Hive,现在来制作这个插件,可以从关系数据库中读取数据,然后存储到hive。这样,用户只要在airflow配置一下要抽数的database, table和目标hive table就可以实现每天数据入库了。

异构数据传输转换工具很多, 最简单的就是使用原生的dump工具,将数据dump下来,然后import到另一个数据库里。

比如postgres dump

${sql}查询的列导出到文件${export_data_file}

psql -h$SRC_HOST_IP -U$SRC_USER_NAME -d$SRC_DB -p$SRC_HOST_PORT -c  "copy (${sql}) to '${export_data_file}' WITH NULL AS ''"

然后导入hive

LOAD DATA LOCAL INPATH '${export_data_file}' INTO TABLE $TAR_TABLENAME PARTITION (BIZDATE='$BIZ_DATE')

对postgres来说,copy是最快的方案了, 但可能会遇到t,n等各种转义符号,导出的txt文件或者cvs文件格式就会混乱,需要做对应符号转义处理。

同样, mysql 可以直接把数据查询出来

cat search.sql | mysql -h"$SRC_HOST_IP" -u"$SRC_USER_NAME" -p"$SRC_USER_PWD" -P"$SRC_HOST_PORT" -D"$SRC_DB" --default-character-set=${mysql_charset} -N -s | sed "s/NULL/\\N/ig;s/\\\\n//ig" > result.txt

上述这些命令行的好处就是快,不好的地方在于shell命令的脆弱性和错误处理。最终,选择了集成化的数据转换工具datax. datax是阿里巴巴开源的一款异构数据源同步工具, 虽然看起来不怎么更新了,但简单使用还是可以的。https://github.com/alibaba/DataX

datax的用法相对简单,按照文档配置一下读取数据源和目标数据源,然后执行调用就可以了。可以当做命令行工具来使用。

结合airflow,可以自己实现datax插件。通过读取connections拿到数据源链接配置,然后生成datax的配置文件json,最后调用datax执行。下面是一个从pg或者mysql读取数据,导入hive的插件实现。

主要思路是:

  1. hdfs创建一个目录
  2. 生成datax配置文件
  3. datax执行配置文件,将数据抽取到hdfs
  4. hive命令行load hdfs

RDBMS2HiveOperator

# -*- coding: utf-8 -*-  #  #    """  postgres或者mysql 入库到hdfs  """  import os  import signal    from hooks.rdbms_to_hive_hook import RDBMS2HiveHook  from airflow.exceptions import AirflowException  from airflow.models import BaseOperator      class RDBMS2HiveOperator(BaseOperator):      """      传输pg到hive      https://github.com/alibaba/DataX        :param conn_id: pg连接id      :param query_sql : pg查询语句      :param split_pk  : pg分割主键, NONE表示不分割,指定后可以多线程分割,加快传输      :param hive_db   : hive的db      :param hive_table: hive的table      :param hive_table_column  column数组, column={name:a, type: int} 或者逗号分割的字符串, column=a,b,c      :param hive_table_partition 分区bizdate值      """      template_fields = ('query_sql',  'hive_db', 'hive_table','hive_table_partition')      ui_color = '#edd5f1'        @apply_defaults      def __init__(self,                   conn_id,                   query_sql,                   hive_db,                   hive_table,                   hive_table_column,                   hive_table_partition,                   split_pk=None,                   *args,                   **kwargs):          super().__init__(*args, **kwargs)          self.conn_id = conn_id          self.query_sql = query_sql          self.split_pk = split_pk          self.hive_db = hive_db          self.hive_table = hive_table          self.hive_table_column = hive_table_column          self.hive_table_partition = hive_table_partition          def execute(self, context):          """          Execute          """          task_id = context['task_instance'].dag_id + "#" + context['task_instance'].task_id            self.hook = RDBMS2HiveHook(                          task_id = task_id,                          conn_id = self.conn_id,                          query_sql = self.query_sql,                          split_pk=self.split_pk,                          hive_db=self.hive_db,                          hive_table=self.hive_table,                          hive_table_column=self.hive_table_column,                          hive_table_partition=self.hive_table_partition                          )          self.hook.execute(context=context)          def on_kill(self):          self.log.info('Sending SIGTERM signal to bash process group')          os.killpg(os.getpgid(self.hook.sp.pid), signal.SIGTERM)  

RDBMS2HiveHook

# -*- coding: utf-8 -*-  #    """  datax入库hive  """  import subprocess  import uuid  import json  import os    from airflow.exceptions import AirflowException  from airflow.hooks.base_hook import BaseHook      class RDBMS2HiveHook(BaseHook):      """      Datax执行器      """        def __init__(self,                   task_id,                   conn_id,                   query_sql,                   hive_db,                   hive_table,                   hive_table_column,                   hive_table_partition,                   split_pk=None):          self.task_id = task_id          self.conn = self.get_connection(conn_id)          self.query_sql = query_sql          self.split_pk = split_pk          self.hive_db = hive_db          self.hive_table = hive_table          self.hive_table_partition = hive_table_partition          self.log.info("Using connection to: {}:{}/{}".format(self.conn.host, self.conn.port, self.conn.schema))            self.hive_table_column = hive_table_column          if isinstance(hive_table_column, str):              self.hive_table_column = []              cl = hive_table_column.split(',')              for item in cl:                  hive_table_column_item = {                      "name": item,                      "type": "string"                  }                  self.hive_table_column.append(hive_table_column_item)          def Popen(self, cmd, **kwargs):          """          Remote Popen            :param cmd: command to remotely execute          :param kwargs: extra arguments to Popen (see subprocess.Popen)          :return: handle to subprocess          """          self.sp = subprocess.Popen(              cmd,              stdout=subprocess.PIPE,              stderr=subprocess.STDOUT,              **kwargs)            for line in iter(self.sp.stdout):              self.log.info(line.strip().decode('utf-8'))            self.sp.wait()            self.log.info("Command exited with return code %s", self.sp.returncode)            if self.sp.returncode:              raise AirflowException("Execute command failed")            def generate_setting(self):          """           datax速度等设置          """          self.setting= {              "speed": {                   "byte": 104857600              },              "errorLimit": {                  "record": 0,                  "percentage": 0.02              }          }          return self.setting        def generate_reader(self):          """          datax reader          """          conn_type = 'mysql'          reader_name = 'mysqlreader'          if(self.conn.conn_type == 'postgres'):              conn_type = 'postgresql'              reader_name = 'postgresqlreader'            self.jdbcUrl =  "jdbc:"+conn_type+"://"+self.conn.host.strip()+":"+str(self.conn.port)+"/"+ self.conn.schema.strip()          self.reader =  {              "name": reader_name,              "parameter": {                  "username": self.conn.login.strip(),                  "password": self.conn.password.strip(),                  "connection": [                      {                          "querySql": [                              self.query_sql                          ],                          "jdbcUrl": [                              self.jdbcUrl                          ]                      }                  ]              }          }            return self.reader        def generate_writer(self):          """          datax hdafs writer          """          self.file_type = "text"          self.hdfs_path = "/datax/"+self.hive_db+"/"+self.hive_table+"/"+self.hive_table_partition          self.log.info("临时存储目录:{}".format(self.hdfs_path))          self.writer = {                      "name": "hdfswriter",                      "parameter": {                          "defaultFS": "hdfs://nameservice1",                          "hadoopConfig": {                              "dfs.nameservices": "nameservice1",                              "dfs.ha.automatic-failover.enabled.nameservice1": True,                              "ha.zookeeper.quorum": "bigdata2-prod-nn01.ryan-miao.com:2181,bigdata2-prod-nn02.ryan-miao.com:2181,bigdata2-prod-nn03.ryan-miao.com:2181",                              "dfs.ha.namenodes.nameservice1": "namenode117,namenode124",                              "dfs.namenode.rpc-address.nameservice1.namenode117": "bigdata2-prod-nn01.ryan-miao.com:8020",                              "dfs.namenode.servicerpc-address.nameservice1.namenode117": "bigdata2-prod-nn01.ryan-miao.com:8022",                              "dfs.namenode.http-address.nameservice1.namenode117": "bigdata2-prod-nn01.ryan-miao.com:50070",                              "dfs.namenode.https-address.nameservice1.namenode117": "bigdata2-prod-nn01.ryan-miao.com:50470",                              "dfs.namenode.rpc-address.nameservice1.namenode124": "bigdata2-prod-nn02.ryan-miao.com:8020",                              "dfs.namenode.servicerpc-address.nameservice1.namenode124": "bigdata2-prod-nn02.ryan-miao.com:8022",                              "dfs.namenode.http-address.nameservice1.namenode124": "bigdata2-prod-nn02.ryan-miao.com:50070",                              "dfs.namenode.https-address.nameservice1.namenode124": "bigdata2-prod-nn02.ryan-miao.com:50470",                              "dfs.replication": 3,                              "dfs.client.failover.proxy.provider.nameservice1": "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider"                          },                          "fileType": self.file_type,                          "path": self.hdfs_path,                          "fileName": self.task_id,                          "column": self.hive_table_column,                          "writeMode": "nonConflict",                          "fieldDelimiter": "t"                      }              }          return self.writer        def generate_config(self):          content = [{              "reader": self.generate_reader(),              "writer": self.generate_writer()          }]            job = {              "setting": self.generate_setting(),              "content": content          }            config = {              "job": job          }            self.target_json = json.dumps(config)            # write json to file          self.json_file= '/tmp/datax_json_'+self.task_id+ uuid.uuid1().hex          # 打开一个文件          fo = open(self.json_file, "w")          fo.write(self.target_json)          fo.close()          self.log.info("write config json {}".format(self.json_file))          return self.json_file            def execute(self, context):          self.generate_config()          # check hdfs_path          hdfs_path = self.hdfs_path          if(not hdfs_path.startswith('/datax/')):              raise AirflowException("hdfs路径填写错误,不在/datax目录下")            # 创建目录          cmd = ['hadoop', 'fs', '-mkdir', '-p', hdfs_path]          self.Popen(cmd)            # 删除文件          if(not hdfs_path.startswith('/datax/')):              raise AirflowException("hdfs路径填写错误,不在/datax目录下")            files_path = hdfs_path+"/*";          try:              cmd = ['hadoop', 'fs', '-rm', files_path]              self.Popen(cmd)          except Exception:              self.log.info('ignore err, just make sure the dir is clean')              pass              # 上传文件          datax_home = '/data/opt/datax/bin'          cmd = [ 'python', datax_home + '/datax.py', self.json_file]          self.Popen(cmd)          # 删除配置文件          os.remove(self.json_file)            # hive加载          #hive load data from hdfs          hql = "LOAD DATA INPATH '"+ hdfs_path + "' OVERWRITE  INTO TABLE "                 + self.hive_db+"."+self.hive_table + " PARTITION (bizdate="+ self.hive_table_partition  +")"          cmd = ['hive', '-e', """ + hql + """]          self.Popen(cmd)

如何使用

  1. admin登录airflow
  2. 配置connection, 配置pg或者mysql的数据库
  3. 修改hdfs集群配置信息
  4. 创建一个DAG
from airflow import DAG    from operators.rdbms_to_hive_operator import RDBMS2HiveOperator  from datetime import datetime, timedelta  from dag_utils import compass_utils      default_args = {      'owner': 'ryanmiao',      'depends_on_past': False,      'start_date': datetime(2019, 5, 1, 9),      'on_failure_callback': compass_utils.failure_callback(dingding_conn_id='dingding_bigdata', receivers='ryanmiao'),      # 'on_success_callback': compass_utils.success_callback(dingding_conn_id='dingding_bigdata', receivers='ryanmiao'),      'retries': 0  }    dag = DAG(      'example_pg2hive', default_args=default_args, schedule_interval=None)    # CREATE TABLE test.pg2hive_test(  #      ftime int,  #      raw_cp_count int,  #      raw_to_delete_cp_count bigint,  #      create_time timestamp  #      )  #  COMMENT '这个是测试datax表'  #  PARTITIONED BY (bizdate int)  # ROW FORMAT DELIMITED  # FIELDS TERMINATED BY 't'  # LINES TERMINATED BY 'n'  #  STORED AS TEXTFILE;    hive_table_column = "ftime,raw_cp_count,raw_to_delete_cp_count,create_time"    t1 = RDBMS2HiveOperator(      task_id='pg2hive',      conn_id='pg_rdb_poi',      query_sql='select ftime, raw_cp_count, raw_to_delete_cp_count, create_time from tbl_poi_report limit 1000',      hive_db='test',      hive_table='pg2hive_test',      hive_table_column=hive_table_column,      hive_table_partition="{{ ds_nodash }}",      dag=dag  )