Named Entity Recognition with Deep Learning (Part 5): Using the Model
- October 6, 2019
- Notes
In this article you will learn how to build a REST-style named entity extraction API on top of the trained model: you pass in a sentence, and the API extracts and returns the person names, addresses, organizations, companies, products, and time expressions it contains.
The core module: entity_extractor.py
Key functions
```python
# load the entity recognition model
def person_model_init():
    ...

# predict the entities in a sentence
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph,
            input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length):
    ...
```
Full code
```python
# -*- coding: utf-8 -*-
"""
Model-based entity extraction
"""
__author__ = '程序员一一涤生'

import codecs
import os
import pickle
from datetime import datetime
from pprint import pprint

import numpy as np
import tensorflow as tf
from bert_base.bert import tokenization, modeling
from bert_base.train.models import create_model, InputFeatures
from bert_base.train.train_helper import get_args_parser

args = get_args_parser()


def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):
    feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')
    input_ids = np.reshape([feature.input_ids], (batch_size, max_seq_length))
    input_mask = np.reshape([feature.input_mask], (batch_size, max_seq_length))
    segment_ids = np.reshape([feature.segment_ids], (batch_size, max_seq_length))
    label_ids = np.reshape([feature.label_ids], (batch_size, max_seq_length))
    return input_ids, input_mask, segment_ids, label_ids


def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph,
            input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length):
    with graph.as_default():
        start = datetime.now()
        # print(id2label)
        sentence = tokenizer.tokenize(sentence)
        # print('your input is:{}'.format(sentence))
        input_ids, input_mask, segment_ids, label_ids = convert(
            sentence, model_dir, label_list, tokenizer, batch_size, max_seq_length)
        feed_dict = {input_ids_p: input_ids,
                     input_mask_p: input_mask}
        # run the session on the current feed_dict
        pred_ids_result = sess.run([pred_ids], feed_dict)
        pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)
        # print(pred_ids_result)
        print(pred_label_result)
        # todo: combination strategy
        result = strage_combined(sentence, pred_label_result[0], labels_config)
        print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
        return result, pred_label_result


def convert_id_to_label(pred_ids_result, idx2label, batch_size):
    """
    Convert the id-based predictions back into label sequences.
    :param pred_ids_result:
    :param idx2label:
    :return:
    """
    result = []
    for row in range(batch_size):
        curr_seq = []
        for ids in pred_ids_result[row][0]:
            if ids == 0:
                break
            curr_label = idx2label[ids]
            if curr_label in ['[CLS]', '[SEP]']:
                continue
            curr_seq.append(curr_label)
        result.append(curr_seq)
    return result


def strage_combined(tokens, tags, labels_config):
    """
    Combination strategy: group the predicted entities by label type.
    :param tokens:
    :param tags:
    :param labels_config:
    :return:
    """

    def get_output(rs, data, type):
        words = []
        for i in data:
            words.append(str(i.word).replace("#", ""))
            # words.append(i.word)
        rs[type] = words
        return rs

    res = Result(labels_config)
    if len(tokens) > len(tags):
        tokens = tokens[:len(tags)]
    labels_dict = res.get_result(tokens, tags)
    arr = []
    for k, v in labels_dict.items():
        arr.append((k, v))
    rs = {}
    for item in arr:
        rs = get_output(rs, item[1], item[0])
    return rs


def convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):
    """
    Analyze one example: convert its characters and labels to ids,
    then pack everything into an InputFeatures object.
    :param ex_index: index
    :param example: one example
    :param label_list: list of labels
    :param max_seq_length:
    :param tokenizer:
    :param mode:
    :return:
    """
    label_map = {}
    # enumerate from 1 so that label indexing starts at 1
    for (i, label) in enumerate(label_list, 1):
        label_map[label] = i
    # persist the label->index map
    if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
        with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
            pickle.dump(label_map, w)
    tokens = example
    # tokens = tokenizer.tokenize(example.text)
    # truncate the sequence; -2 leaves room for the sentence-start and sentence-end markers
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")  # mark the sentence start with [CLS]
    segment_ids.append(0)
    # append("O") or append("[CLS]")? Either works: "O" would mean one label fewer,
    # but tagging the sentence start and end with dedicated markers is also fine
    label_ids.append(label_map["[CLS]"])
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(0)
    ntokens.append("[SEP]")  # mark the sentence end with [SEP]
    segment_ids.append(0)
    # append("O") or append("[SEP]") not sure!
    label_ids.append(label_map["[SEP]"])
    # convert the tokens (ntokens) to ids
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
    input_mask = [1] * len(input_ids)
    # pad up to max_seq_length
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        label_ids.append(0)
        ntokens.append("**NULL**")
        # label_mask.append(0)
    # print(len(input_ids))
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length
    # assert len(label_mask) == max_seq_length
    # pack into a structured object
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
        # label_mask = label_mask
    )
    return feature


class Pair(object):
    def __init__(self, word, start, end, type, merge=False):
        self.__word = word
        self.__start = start
        self.__end = end
        self.__merge = merge
        self.__types = type

    @property
    def start(self):
        return self.__start

    @property
    def end(self):
        return self.__end

    @property
    def merge(self):
        return self.__merge

    @property
    def word(self):
        return self.__word

    @property
    def types(self):
        return self.__types

    @word.setter
    def word(self, word):
        self.__word = word

    @start.setter
    def start(self, start):
        self.__start = start

    @end.setter
    def end(self, end):
        self.__end = end

    @merge.setter
    def merge(self, merge):
        self.__merge = merge

    @types.setter
    def types(self, type):
        self.__types = type

    def __str__(self) -> str:
        line = []
        line.append('entity:{}'.format(self.__word))
        line.append('start:{}'.format(self.__start))
        line.append('end:{}'.format(self.__end))
        line.append('merge:{}'.format(self.__merge))
        line.append('types:{}'.format(self.__types))
        return '\t'.join(line)


class Result(object):
    def __init__(self, labels_config):
        self.others = []
        self.labels_config = labels_config
        self.labels = {}
        for la in self.labels_config:
            self.labels[la] = []

    def get_result(self, tokens, tags):
        # collect the tagging result first
        self.result_to_json(tokens, tags)
        return self.labels

    def result_to_json(self, string, tags):
        """
        Combine the predicted tag sequence with the input sequence
        and turn them into entities.
        :param string: input sequence
        :param tags: predicted tags
        :return:
        """
        item = {"entities": []}
        entity_name = ""
        entity_start = 0
        idx = 0
        last_tag = ''
        for char, tag in zip(string, tags):
            if tag[0] == "S":
                self.append(char, idx, idx + 1, tag[2:])
                item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})
            elif tag[0] == "B":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append(
                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
                entity_name += char
                entity_start = idx
            elif tag[0] == "I":
                entity_name += char
            elif tag[0] == "O":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append(
                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
            else:
                entity_name = ""
                entity_start = idx
            idx += 1
            last_tag = tag
        if entity_name != '':
            self.append(entity_name, entity_start, idx, last_tag[2:])
            item["entities"].append({"word": entity_name, "start": entity_start, "end": idx,
                                     "type": last_tag[2:]})
        return item

    def append(self, word, start, end, tag):
        if tag in self.labels_config:
            self.labels[tag].append(Pair(word, start, end, tag))
        else:
            self.others.append(Pair(word, start, end, tag))


def person_model_init():
    return model_init("person")


def model_init(model_name):
    if os.name == 'nt':
        # windows path config
        model_dir = 'E:/quickstart/deeplearning/nlp_demo/%s/model' % model_name
        bert_dir = 'E:/quickstart/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
    else:
        # linux path config
        model_dir = '/home/yjy/project/deeplearning/nlp_demo/%s/model' % model_name
        bert_dir = '/home/yjy/project/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
    batch_size = 1
    max_seq_length = 500
    print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
    if not os.path.exists(os.path.join(model_dir, "checkpoint")):
        raise Exception("failed to get checkpoint. going to return ")
    # load the label->id map
    with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
        label2id = pickle.load(rf)
        id2label = {value: key for key, value in label2id.items()}
    with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
        label_list = pickle.load(rf)
    num_labels = len(label_list) + 1
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    graph = tf.Graph()
    sess = tf.Session(graph=graph, config=gpu_config)
    with graph.as_default():
        print("going to restore checkpoint")
        # sess.run(tf.global_variables_initializer())
        input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")
        bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
        (total_loss, logits, trans, pred_ids) = create_model(
            bert_config=bert_config, is_training=False, input_ids=input_ids_p,
            input_mask=input_mask_p, segment_ids=None, labels=None, num_labels=num_labels,
            use_one_hot_embeddings=False, dropout_rate=1.0)
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    tokenizer = tokenization.FullTokenizer(
        vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case)
    return (model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids, tokenizer, sess, max_seq_length)


if __name__ == "__main__":
    (_model_dir, _batch_size, _id2label, _label_list, _graph, _input_ids_p, _input_mask_p,
     _pred_ids, _tokenizer, _sess, _max_seq_length) = person_model_init()
    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]
    while True:
        print('input the test sentence:')
        _sentence = str(input())
        pred_rs, pred_label_result = predict(_sentence, PERSON_LABELS, _model_dir, _batch_size,
                                             _id2label, _label_list, _graph, _input_ids_p,
                                             _input_mask_p, _pred_ids, _tokenizer, _sess,
                                             _max_seq_length)
        pprint(pred_rs)
```
Writing the REST-style API
We will use Python's flask framework to serve the REST API.
First, create a new Python project and put the following directories and files under the project root:
- The bert_base and bert_model_info directories and their files can be found in the cloud-drive project shared in the previous article, Named Entity Recognition with Deep Learning (Part 4): Model Training;
- The model under the person directory is the named entity recognition model we trained in the previous article, together with a few companion files; you can take it from the project's output directory. The resulting layout is sketched below.
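For orientation, the layout implied by the files created in this article looks roughly like this (a sketch; the static directory holds the optional index.html that nlp_main.py redirects to):

```
nlp_demo/
├── bert_base/                         # from the cloud-drive project
├── bert_model_info/
│   └── chinese_L-12_H-768_A-12/
├── person/
│   └── model/                         # the trained model from Part 4
├── static/                            # optional static pages (index.html)
├── entity_extractor.py
├── flaskr.py
├── nlp_config.py
├── nlp_main.py
├── person_ner_resource.py
└── requirements.txt
```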
Then, create the entry file nlp_main.py with the following content:
```python
# -*- coding: utf-8 -*-
"""
flask entry point
"""
__author__ = '程序员一一涤生'

from flask import jsonify, make_response, redirect

import nlp_config as nc
from flaskr import create_app, loadProjContext

# load the flask configuration
# app = create_app('config.DevelopmentConfig')
app = create_app(nc.config['default'])
# load the project context
loadProjContext()


@app.errorhandler(404)
def not_found(error):
    return make_response(jsonify({'error': 'Not found'}), 404)


@app.errorhandler(400)
def bad_request(error):
    return make_response(jsonify({'error': '400 Bad Request: invalid parameters or parameter content'}), 400)


@app.route('/')
def index_sf():
    # return render_template('index.html')
    return redirect('index.html')


if __name__ == '__main__':
    app.run('localhost', 5006, use_reloader=False)
```
Next, create the project's initialization file flaskr.py, which pre-configures and loads a few things at startup:
```python
# -*- coding: utf-8 -*-
"""
flask initialization
"""
from logging.config import dictConfig

from flask import Flask
from flask_cors import CORS

import person_ner_resource
from entity_extractor import person_model_init
from person_ner_resource import person

__author__ = '程序员一一涤生'


def create_app(config_type):
    dictConfig({
        'version': 1,
        'formatters': {'default': {
            'format': '[%(asctime)s] %(name)s %(levelname)s in %(module)s %(lineno)d: %(message)s',
        }},
        'handlers': {'wsgi': {
            'class': 'logging.StreamHandler',
            'stream': 'ext://flask.logging.wsgi_errors_stream',
            'formatter': 'default'
        }},
        'root': {
            'level': 'DEBUG',
            # 'level': 'WARN',
            # 'level': 'INFO',
            'handlers': ['wsgi']
        }
    })
    # load the flask configuration
    app = Flask(__name__, static_folder='static', static_url_path='')
    # r'/*' is a wildcard allowing cross-origin requests to every URL on this
    # server; "origins": '*' allows any IP to access those URLs cross-origin
    # CORS(app, resources=r'/*', origins=['192.168.1.104'])
    CORS(app, resources={r"/*": {"origins": '*'}})
    app.config.from_object(config_type)
    app.register_blueprint(person, url_prefix='/person')
    # initialize the application context
    ctx = app.app_context()
    ctx.push()
    return app


def loadProjContext():
    # load the person-name extraction model
    (model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
     pred_ids, tokenizer, sess, max_seq_length) = person_model_init()
    person_ner_resource.model_dir = model_dir
    person_ner_resource.batch_size = batch_size
    person_ner_resource.id2label = id2label
    person_ner_resource.label_list = label_list
    person_ner_resource.graph = graph
    person_ner_resource.input_ids_p = input_ids_p
    person_ner_resource.input_mask_p = input_mask_p
    person_ner_resource.pred_ids = pred_ids
    person_ner_resource.tokenizer = tokenizer
    person_ner_resource.sess = sess
    person_ner_resource.max_seq_length = max_seq_length
```
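One detail worth calling out: create_app pushes an application context (ctx.push()) before returning, so code that runs outside a request, such as the current_app.config['PERSON_LABELS'] lookup in the blueprint below, can resolve the app. A minimal standalone sketch, separate from the project, of what goes wrong without it:

```python
from flask import Flask, current_app

app = Flask(__name__)
app.config['PERSON_LABELS'] = ["PERSON_NAME"]

try:
    # fails: no application context has been pushed yet
    print(current_app.config['PERSON_LABELS'])
except RuntimeError as err:
    print('without context:', err)

with app.app_context():  # what ctx = app.app_context(); ctx.push() sets up
    print('with context:', current_app.config['PERSON_LABELS'])
```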
Then, create the configuration file nlp_config.py, used to switch between production, development, and test environments:
```python
# -*- coding: utf-8 -*-
"""
Flask configuration module
"""
import os

__author__ = '程序员一一涤生'

basedir = os.path.abspath(os.path.dirname(__file__))


class BaseConfig:  # base configuration class
    # reconstructed byte escapes; any random byte string works as the secret key
    SECRET_KEY = b'\xe4r\x04\xb5\xb2\x00\xf1\xadf\xa3\xf3V\x03\xc5\x9f\x82$^\xa25O\xf0R\xda'
    # the default JSONIFY_MIMETYPE does not include '; charset=utf-8'
    JSONIFY_MIMETYPE = 'application/json; charset=utf-8'
    # if left on, Chinese characters in jsonify responses come back as Unicode escapes
    JSON_AS_ASCII = False
    ENCODING = 'utf-8'
    # custom configuration items
    ADDRESS_LABELS = ["COUNTY", "STREET", "COMMUNITY", "ROAD", "NUM", "POI", "CITY", "VILLAGE"]
    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]


class DevelopmentConfig(BaseConfig):
    ENV = 'development'
    DEBUG = True


class TestingConfig(BaseConfig):
    TESTING = True
    WTF_CSRF_ENABLED = False


class ProductionConfig(BaseConfig):
    DEBUG = False


config = {
    'testing': TestingConfig,
    'default': DevelopmentConfig
    # 'default': ProductionConfig
}
```
Next, create the person-name recognition API file person_ner_resource.py:
```python
# -*- coding: utf-8 -*-
"""
named entity recognition API
"""
from entity_extractor import predict

__author__ = '程序员一一涤生'

from flask import Blueprint, make_response, request, current_app
from flask import jsonify

person = Blueprint('person', __name__)

# placeholders filled in by flaskr.loadProjContext() at startup
model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, \
    pred_ids, tokenizer, sess, max_seq_length = (None,) * 11


@person.route('/extract', methods=['POST'])
def extract():
    params = request.get_json()
    if 't' not in params or params['t'] is None or len(params['t']) > 500 or len(params['t']) < 2:
        return make_response(jsonify({'error': 'Text length out of range; allowed length: 2~500'}), 400)
    sentence = params['t']
    # make sure the text ends like a sentence
    sentence = sentence + "。" if not sentence.endswith((",", "。", "!", "?")) else sentence
    # extract entities with the model
    pred_rs, pred_label_result = predict(sentence, current_app.config['PERSON_LABELS'], model_dir,
                                         batch_size, id2label, label_list, graph, input_ids_p,
                                         input_mask_p, pred_ids, tokenizer, sess, max_seq_length)
    print(sentence)
    return jsonify(pred_rs)


if __name__ == '__main__':
    pass
```
Next, put a requirements.txt file under the project root with the following content:
```
absl-py==0.7.0
astor==0.7.1
backcall==0.1.0
backports.weakref==1.0rc1
bleach==1.5.0
certifi==2016.2.28
click==6.7
colorama==0.4.1
colorful==0.5.0
decorator==4.3.2
defusedxml==0.5.0
entrypoints==0.3
Flask==1.0.2
Flask-Cors==3.0.3
gast==0.2.2
grpcio==1.18.0
h5py==2.9.0
html5lib==0.9999999
ipykernel==5.1.0
ipython==7.2.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
itsdangerous==0.24
jedi==0.13.2
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
Markdown==3.0.1
MarkupSafe==1.1.0
mistune==0.8.4
mock==3.0.5
nbconvert==5.4.0
nbformat==4.4.0
notebook==5.7.4
numpy==1.16.0
pandocfilters==1.4.2
parso==0.3.2
pickleshare==0.7.5
prettyprinter==0.17.0
prometheus-client==0.5.0
prompt-toolkit==2.0.8
protobuf==3.6.1
Pygments==2.3.1
python-dateutil==2.7.5
pywinpty==0.5.5
pyzmq==17.1.2
qtconsole==4.4.3
Send2Trash==1.5.0
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.2
tornado==5.1.1
traitlets==4.3.2
wcwidth==0.1.7
Werkzeug==0.14.1
widgetsnbextension==3.4.2
wincertstore==0.2
```
Then, run the following command to install the packages listed in requirements.txt:
```
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
```
Once these steps are done, we can try to start the project.
Starting the project
Run the following command to start the flask project:
```
python nlp_main.py
```
Calling the API
This article uses postman to call the named entity extraction API, at this address:
```
http://localhost:5006/person/extract
```
The effect of a call:
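The same request postman sends can also be scripted with Python's requests library, which doubles as a quick smoke test. A minimal sketch; the sentence and the printed entities are illustrative only, since the exact output depends on your trained model:

```python
import requests

# POST a JSON body with the text under the 't' key, as extract() expects
resp = requests.post('http://localhost:5006/person/extract',
                     json={'t': '昨天,马云在杭州出席了阿里巴巴的发布会。'})
print(resp.json())
# Illustrative output (a dict keyed by the PERSON_LABELS types):
# {'TIME': ['昨天'], 'LOCATION': ['杭州'], 'PERSON_NAME': ['马云'],
#  'ORG_NAME': [], 'COMPANY_NAME': ['阿里巴巴'], 'PRODUCT_NAME': []}
```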
Note that inference takes roughly 2 to 3 seconds per call on a CPU. If the project is deployed on a machine with a GPU that supports deep learning, the API responds much faster; just don't forget to install tensorflow-gpu in place of tensorflow.
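For example, the swap could look like this, matching the TensorFlow version pinned in requirements.txt (assuming a compatible CUDA/cuDNN setup):

```
pip uninstall tensorflow
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow-gpu==1.13.1
```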
OK, that's it for this article. At this point we have built a deep-learning project that can extract person names, addresses, organizations, companies, products, and time expressions from natural language. Starting from the next article, we will look at the deep learning algorithms this project relies on, BERT and CRF; understanding them will make it much clearer why the model can accurately pull the entities we want out of a sentence.
That's all for now. Thanks for reading O(∩_∩)O, bye~