Named Entity Recognition with Deep Learning (Part 5): Using the Model

  • October 6, 2019
  • Notes

In this article, you will learn how to build a REST-style named entity extraction interface on top of the trained model: you pass in a sentence, and the interface extracts and returns the person names, addresses, organizations, companies, products, and time expressions it contains.

The core module: entity_extractor.py

Key functions
# load the entity recognition model
def person_model_init():
    ...


# predict the entities in a sentence
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids,
            tokenizer,
            sess, max_seq_length):
    ...
Complete code
# -*- coding: utf-8 -*-

"""
Model-based entity extraction
"""
__author__ = '程式設計師一一滌生'

import codecs
import os
import pickle
from datetime import datetime
from pprint import pprint

import numpy as np
import tensorflow as tf
from bert_base.bert import tokenization, modeling
from bert_base.train.models import create_model, InputFeatures
from bert_base.train.train_helper import get_args_parser

args = get_args_parser()


def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):
    feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')
    input_ids = np.reshape([feature.input_ids], (batch_size, max_seq_length))
    input_mask = np.reshape([feature.input_mask], (batch_size, max_seq_length))
    segment_ids = np.reshape([feature.segment_ids], (batch_size, max_seq_length))
    label_ids = np.reshape([feature.label_ids], (batch_size, max_seq_length))
    return input_ids, input_mask, segment_ids, label_ids


def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids,
            tokenizer,
            sess, max_seq_length):
    with graph.as_default():
        start = datetime.now()
        sentence = tokenizer.tokenize(sentence)
        input_ids, input_mask, segment_ids, label_ids = convert(sentence, model_dir, label_list, tokenizer,
                                                                batch_size, max_seq_length)

        feed_dict = {input_ids_p: input_ids,
                     input_mask_p: input_mask}
        # run the session to get the prediction for the current feed_dict
        pred_ids_result = sess.run([pred_ids], feed_dict)
        pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)
        print(pred_label_result)
        # combine the predicted tag sequence into entities
        result = strage_combined(sentence, pred_label_result[0], labels_config)
        print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
    return result, pred_label_result


def convert_id_to_label(pred_ids_result, idx2label, batch_size):
    """
    Convert the id-based prediction back into the real label sequence.
    :param pred_ids_result:
    :param idx2label:
    :return:
    """
    result = []
    for row in range(batch_size):
        curr_seq = []
        for ids in pred_ids_result[row][0]:
            if ids == 0:
                break
            curr_label = idx2label[ids]
            if curr_label in ['[CLS]', '[SEP]']:
                continue
            curr_seq.append(curr_label)
        result.append(curr_seq)
    return result


def strage_combined(tokens, tags, labels_config):
    """
    Combination strategy: merge the tokens and their predicted tags into a
    dict that maps each label type to the list of extracted words.
    :param tokens: the tokenized input sequence
    :param tags: the predicted tag sequence
    :param labels_config: the label types to report
    :return:
    """

    def get_output(rs, data, type):
        words = []
        for i in data:
            words.append(str(i.word).replace("#", ""))
        rs[type] = words
        return rs

    eval_result = Result(labels_config)
    if len(tokens) > len(tags):
        tokens = tokens[:len(tags)]
    labels_dict = eval_result.get_result(tokens, tags)
    rs = {}
    for k, v in labels_dict.items():
        rs = get_output(rs, v, k)
    return rs


def convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):
    """
    Analyze one sample: convert its characters and labels to ids, then pack
    everything into an InputFeatures object.
    :param ex_index: index
    :param example: one sample
    :param label_list: the label list
    :param max_seq_length:
    :param tokenizer:
    :param mode:
    :return:
    """
    label_map = {}
    # index the labels starting from 1
    for (i, label) in enumerate(label_list, 1):
        label_map[label] = i
    # save the label->index map
    if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
        with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
            pickle.dump(label_map, w)
    tokens = example
    # truncate the sequence; -2 because a sentence-start and a sentence-end mark will be added
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")  # mark the sentence start with [CLS]
    segment_ids.append(0)
    # "O" would also work here ("O" keeps the label set smaller), but tagging the
    # sentence start and end with dedicated marks is fine as well
    label_ids.append(label_map["[CLS]"])
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(0)
    ntokens.append("[SEP]")  # mark the sentence end with [SEP]
    segment_ids.append(0)
    label_ids.append(label_map["[SEP]"])
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # convert the tokens (ntokens) to ids
    input_mask = [1] * len(input_ids)
    # pad up to max_seq_length
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        label_ids.append(0)
        ntokens.append("**NULL**")
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length
    # pack the result into an InputFeatures object
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
    )
    return feature


class Pair(object):
    def __init__(self, word, start, end, type, merge=False):
        self.__word = word
        self.__start = start
        self.__end = end
        self.__merge = merge
        self.__types = type

    @property
    def start(self):
        return self.__start

    @property
    def end(self):
        return self.__end

    @property
    def merge(self):
        return self.__merge

    @property
    def word(self):
        return self.__word

    @property
    def types(self):
        return self.__types

    @word.setter
    def word(self, word):
        self.__word = word

    @start.setter
    def start(self, start):
        self.__start = start

    @end.setter
    def end(self, end):
        self.__end = end

    @merge.setter
    def merge(self, merge):
        self.__merge = merge

    @types.setter
    def types(self, type):
        self.__types = type

    def __str__(self) -> str:
        line = []
        line.append('entity:{}'.format(self.__word))
        line.append('start:{}'.format(self.__start))
        line.append('end:{}'.format(self.__end))
        line.append('merge:{}'.format(self.__merge))
        line.append('types:{}'.format(self.__types))
        return '\t'.join(line)


class Result(object):
    def __init__(self, labels_config):
        self.others = []
        self.labels_config = labels_config
        self.labels = {}
        for la in self.labels_config:
            self.labels[la] = []

    def get_result(self, tokens, tags):
        # first collect the tagging result
        self.result_to_json(tokens, tags)
        return self.labels

    def result_to_json(self, string, tags):
        """
        Combine the model's tag sequence with the input sequence and convert
        them into the final result.
        :param string: the input sequence
        :param tags: the tagging result
        :return:
        """
        item = {"entities": []}
        entity_name = ""
        entity_start = 0
        idx = 0
        last_tag = ''

        for char, tag in zip(string, tags):
            if tag[0] == "S":
                self.append(char, idx, idx + 1, tag[2:])
                item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})
            elif tag[0] == "B":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append(
                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
                entity_name += char
                entity_start = idx
            elif tag[0] == "I":
                entity_name += char
            elif tag[0] == "O":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append(
                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
            else:
                entity_name = ""
                entity_start = idx
            idx += 1
            last_tag = tag
        if entity_name != '':
            self.append(entity_name, entity_start, idx, last_tag[2:])
            item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
        return item

    def append(self, word, start, end, tag):
        if tag in self.labels_config:
            self.labels[tag].append(Pair(word, start, end, tag))
        else:
            self.others.append(Pair(word, start, end, tag))


def person_model_init():
    return model_init("person")


def model_init(model_name):
    if os.name == 'nt':  # Windows path config
        model_dir = 'E:/quickstart/deeplearning/nlp_demo/%s/model' % model_name
        bert_dir = 'E:/quickstart/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
    else:  # Linux path config
        model_dir = '/home/yjy/project/deeplearning/nlp_demo/%s/model' % model_name
        bert_dir = '/home/yjy/project/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'

    batch_size = 1
    max_seq_length = 500

    print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
    if not os.path.exists(os.path.join(model_dir, "checkpoint")):
        raise Exception("failed to get checkpoint. going to return ")

    # load the label->id dictionary
    with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
        label2id = pickle.load(rf)
        id2label = {value: key for key, value in label2id.items()}

    with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
        label_list = pickle.load(rf)
    num_labels = len(label_list) + 1

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    graph = tf.Graph()
    sess = tf.Session(graph=graph, config=gpu_config)

    with graph.as_default():
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")

        bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
        (total_loss, logits, trans, pred_ids) = create_model(
            bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
            segment_ids=None,
            labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)

        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))

    tokenizer = tokenization.FullTokenizer(
        vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case)

    return model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length


if __name__ == "__main__":
    _model_dir, _batch_size, _id2label, _label_list, _graph, _input_ids_p, _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length = person_model_init()
    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]
    while True:
        print('input the test sentence:')
        _sentence = str(input())
        pred_rs, pred_label_result = predict(_sentence, PERSON_LABELS, _model_dir, _batch_size, _id2label, _label_list,
                                             _graph,
                                             _input_ids_p,
                                             _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length)
        pprint(pred_rs)
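Thanks to the __main__ block, the module can also be smoke-tested on its own from the command line. A session looks roughly like the following; the sentence and the extracted entities are only illustrative (the actual output depends on your trained model), and the debug prints from predict are omitted:

python entity_extractor.py
input the test sentence:
馬雲1999年在杭州創立了阿里巴巴。
{'COMPANY_NAME': ['阿里巴巴'],
 'LOCATION': ['杭州'],
 'ORG_NAME': [],
 'PERSON_NAME': ['馬雲'],
 'PRODUCT_NAME': [],
 'TIME': ['1999年']}

Note that the result is a dict keyed by label type, with one entry per label in labels_config, empty lists included.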

Writing the REST-style interface

We will use Python's Flask framework to provide the REST interface.

First, create a new Python project and place the following directories and files under the project root (a sketch of the resulting layout follows the list):
  • The bert_base directory and the bert_model_info directory (with their files) can be found in the cloud-drive project given in the previous article, Named Entity Recognition with Deep Learning (Part 4): Model Training;
  • The model directory under person holds the named entity recognition model we trained in the previous article, together with some companion files; you can take it from that project's output directory.
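For orientation, here is a plausible layout of the project root, reconstructed from the paths referenced by the code in this article; treat it as an assumption, not a prescription:

nlp_demo/
├── bert_base/                       # BERT NER training/serving code
├── bert_model_info/
│   └── chinese_L-12_H-768_A-12/     # pre-trained Chinese BERT (vocab.txt, bert_config.json, ...)
├── person/
│   └── model/                       # the model trained in Part 4, plus label2id.pkl / label_list.pkl
├── static/                          # static files served by Flask (e.g. index.html)
├── entity_extractor.py
├── nlp_main.py
├── flaskr.py
├── nlp_config.py
├── person_ner_resource.py
└── requirements.txt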
Then, create the entry file nlp_main.py with the following content:
# -*- coding: utf-8 -*-

"""
flask entry point
"""
from flask import jsonify, make_response, redirect

import nlp_config as nc
from flaskr import create_app, loadProjContext

__author__ = '程式設計師一一滌生'

# load the flask configuration
app = create_app(nc.config['default'])
# load the project context (restores the NER model into memory)
loadProjContext()


@app.errorhandler(404)
def not_found(error):
    return make_response(jsonify({'error': 'Not found'}), 404)


@app.errorhandler(400)
def bad_request(error):
    return make_response(jsonify({'error': '400 Bad Request, invalid parameter or parameter content'}), 400)


@app.route('/')
def index_sf():
    return redirect('index.html')


if __name__ == '__main__':
    # use_reloader=False keeps the debug reloader from loading the TensorFlow model a second time
    app.run('localhost', 5006, use_reloader=False)
Next, create flaskr.py, the initialization file of this Flask project; it presets and loads some information when the project starts. Its content is as follows (the address-extraction imports from the author's larger project are dropped here, since this article only builds the person model):
# -*- coding: utf-8 -*-
"""
flask initialization
"""
from logging.config import dictConfig
from flask import Flask
from flask_cors import CORS
import person_ner_resource
from entity_extractor import person_model_init
from person_ner_resource import person

__author__ = '程式設計師一一滌生'


def create_app(config_type):
    dictConfig({
        'version': 1,
        'formatters': {'default': {
            'format': '[%(asctime)s] %(name)s %(levelname)s in %(module)s %(lineno)d: %(message)s',
        }},
        'handlers': {'wsgi': {
            'class': 'logging.StreamHandler',
            'stream': 'ext://flask.logging.wsgi_errors_stream',
            'formatter': 'default'
        }},
        'root': {
            'level': 'DEBUG',  # consider 'WARN' or 'INFO' in production
            'handlers': ['wsgi']
        }
    })
    # load the flask configuration
    app = Flask(__name__, static_folder='static', static_url_path='')
    # r'/*' matches every URL on this server; "origins": '*' lets any IP access them cross-origin
    CORS(app, resources={r"/*": {"origins": '*'}})
    app.config.from_object(config_type)
    app.register_blueprint(person, url_prefix='/person')
    # initialize the application context
    ctx = app.app_context()
    ctx.push()
    return app


def loadProjContext():
    # load the person-name extraction model
    model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = person_model_init()
    person_ner_resource.model_dir = model_dir
    person_ner_resource.batch_size = batch_size
    person_ner_resource.id2label = id2label
    person_ner_resource.label_list = label_list
    person_ner_resource.graph = graph
    person_ner_resource.input_ids_p = input_ids_p
    person_ner_resource.input_mask_p = input_mask_p
    person_ner_resource.pred_ids = pred_ids
    person_ner_resource.tokenizer = tokenizer
    person_ner_resource.sess = sess
    person_ner_resource.max_seq_length = max_seq_length
Then, create the configuration file nlp_config.py, used to switch between the production, development, and testing environments. Its content is as follows:
# -*- coding: utf-8 -*-

"""
Flask configuration module
"""
import os

__author__ = '程式設計師一一滌生'

basedir = os.path.abspath(os.path.dirname(__file__))


class BaseConfig:  # base configuration class
    SECRET_KEY = b'\xe4r\x04\xb5\xb2\x00\xf1\xadf\xa3\xf3V\x03\xc5\x9f\x82$^\xa25O\xf0R\xda'
    JSONIFY_MIMETYPE = 'application/json; charset=utf-8'  # the default JSONIFY_MIMETYPE lacks '; charset=utf-8'
    JSON_AS_ASCII = False  # if left on, Chinese in JSONIFY responses is rendered as Unicode escapes
    ENCODING = 'utf-8'

    # custom configuration items
    ADDRESS_LABELS = ["COUNTY", "STREET", "COMMUNITY", "ROAD", "NUM", "POI", "CITY", "VILLAGE"]
    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]


class DevelopmentConfig(BaseConfig):
    ENV = 'development'
    DEBUG = True


class TestingConfig(BaseConfig):
    TESTING = True
    WTF_CSRF_ENABLED = False


class ProductionConfig(BaseConfig):
    DEBUG = False


config = {
    'testing': TestingConfig,
    'default': DevelopmentConfig
    # 'default': ProductionConfig
}
Next, create person_ner_resource.py, the interface file for person-name recognition. Its content is as follows:
# -*- coding: utf-8 -*-

"""
named entity recognition interface
"""
from entity_extractor import predict

__author__ = '程式設計師一一滌生'

from flask import Blueprint, make_response, request, current_app
from flask import jsonify

person = Blueprint('person', __name__)

# these module-level globals are populated by flaskr.loadProjContext() at startup
model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = None, None, None, None, None, None, None, None, None, None, None


@person.route('/extract', methods=['POST'])
def extract():
    params = request.get_json()
    if 't' not in params or params['t'] is None or len(params['t']) > 500 or len(params['t']) < 2:
        return make_response(jsonify({'error': 'Text length out of range; it must be 2~500 characters'}), 400)
    sentence = params['t']
    # make the text end like a full sentence
    sentence = sentence + "。" if not sentence.endswith((",", "。", "!", "?")) else sentence
    # extract entities with the model
    pred_rs, pred_label_result = predict(sentence, current_app.config['PERSON_LABELS'], model_dir, batch_size, id2label,
                                         label_list, graph, input_ids_p,
                                         input_mask_p,
                                         pred_ids, tokenizer, sess, max_seq_length)
    print(sentence)
    return jsonify(pred_rs)


if __name__ == '__main__':
    pass
Next, put a requirements.txt file under the project root with the following content:
absl-py==0.7.0
astor==0.7.1
backcall==0.1.0
backports.weakref==1.0rc1
bleach==1.5.0
certifi==2016.2.28
click==6.7
colorama==0.4.1
colorful==0.5.0
decorator==4.3.2
defusedxml==0.5.0
entrypoints==0.3
Flask==1.0.2
Flask-Cors==3.0.3
gast==0.2.2
grpcio==1.18.0
h5py==2.9.0
html5lib==0.9999999
ipykernel==5.1.0
ipython==7.2.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
itsdangerous==0.24
jedi==0.13.2
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
Markdown==3.0.1
MarkupSafe==1.1.0
mistune==0.8.4
mock==3.0.5
nbconvert==5.4.0
nbformat==4.4.0
notebook==5.7.4
numpy==1.16.0
pandocfilters==1.4.2
parso==0.3.2
pickleshare==0.7.5
prettyprinter==0.17.0
prometheus-client==0.5.0
prompt-toolkit==2.0.8
protobuf==3.6.1
Pygments==2.3.1
python-dateutil==2.7.5
pywinpty==0.5.5
pyzmq==17.1.2
qtconsole==4.4.3
Send2Trash==1.5.0
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.2
tornado==5.1.1
traitlets==4.3.2
wcwidth==0.1.7
Werkzeug==0.14.1
widgetsnbextension==3.4.2
wincertstore==0.2
Then run the following command to install the packages listed in requirements.txt:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt

Once the steps above are complete, we can try to start the project.

Starting the project

Run the following command to start the Flask project:

python nlp_main.py

Calling the interface

This article uses Postman to call the named entity extraction interface. The interface address is:

http://localhost:5006/person/extract

A sample call looks like this:
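If you prefer scripting the call to clicking through Postman, here is a minimal client sketch using only the Python standard library. The sample sentence and the response sketched in the final comment are illustrative, not recorded output:

# -*- coding: utf-8 -*-
"""Minimal client sketch for the /person/extract endpoint."""
import json
from urllib import request

# the endpoint accepts a JSON body with a single field "t" (2~500 characters)
payload = json.dumps({"t": "馬雲1999年在杭州創立了阿里巴巴。"}).encode("utf-8")
req = request.Request(
    "http://localhost:5006/person/extract",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    result = json.loads(resp.read().decode("utf-8"))
print(result)
# illustrative shape of the response, keyed by label type:
# {"PERSON_NAME": ["馬雲"], "TIME": ["1999年"], "LOCATION": ["杭州"], "COMPANY_NAME": ["阿里巴巴"], ...}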

Note that on a CPU a single model call takes roughly 2 to 3 seconds, whereas if the project is deployed on a machine with a GPU suited to deep learning, the interface responds much faster. In that case, don't forget to install tensorflow-gpu in place of tensorflow.
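For example, keeping the GPU build on the same version that requirements.txt pins:

pip uninstall -y tensorflow
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow-gpu==1.13.1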

OK, that's all for this article. At this point we have used deep learning to build a project that can extract person names, addresses, organizations, companies, products, and times from natural language. Starting with the next article, we will look at BERT and CRF, the deep learning algorithms this project uses; understanding them will make it much clearer why the model can accurately pull the entities we want out of a sentence.

That's it for this article. Thanks for reading O(∩_∩)O, 88~