python+opencv实现目标跟踪

  • 2020 年 1 月 10 日
  • 筆記

python-opencv3.0新增了一些比较有用的追踪器算法,这里根据官网示例写了一个追踪器类

程序只能在安装了 OpenCV 3.0 及以上版本以及对应 contrib 模块的 Python 环境中运行

#encoding=utf-8
import cv2
from items import MessageItem
import time
import numpy as np

# Watcher module: intrusion (motion) detection and target tracking.


class WatchDog(object):
    """Motion detector that compares frames against a fixed background frame."""

    def __init__(self, frame=None):
        """Create the detector.

        :param frame: optional BGR frame used as the reference background.
            When given, the detector is working immediately.
        """
        self._background = None
        if frame is not None:
            self._background = self._preprocess(frame)
        # Elliptical kernel used to dilate the thresholded difference image.
        self.es = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))

    @staticmethod
    def _preprocess(frame):
        """Return the grayscale, Gaussian-blurred version of a BGR frame."""
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        return cv2.GaussianBlur(gray, (21, 21), 0)

    def isWorking(self):
        """Return True when a background frame has been set."""
        return self._background is not None

    def startWorking(self, frame):
        """Start detecting, using ``frame`` as the background (no-op for None)."""
        if frame is not None:
            self._background = self._preprocess(frame)

    def stopWorking(self):
        """Stop detecting by discarding the background frame."""
        self._background = None

    def analyze(self, frame):
        """Detect motion in ``frame`` relative to the stored background.

        The bounding box with the largest area among contours of at least
        1500 px is drawn onto ``frame`` in place.

        :param frame: BGR frame to analyse.
        :return: ``None`` when ``frame`` is None or the detector is not
            working; otherwise a ``MessageItem`` whose message is
            ``{"coord": [box_or_None], "msg": None}`` where ``box`` is
            ``((x1, y1), (x2, y2))``.
        """
        if frame is None or self._background is None:
            return
        sample_frame = self._preprocess(frame)
        diff = cv2.absdiff(self._background, sample_frame)
        diff = cv2.threshold(diff, 25, 255, cv2.THRESH_BINARY)[1]
        diff = cv2.dilate(diff, self.es, iterations=2)
        # findContours returns (image, contours, hierarchy) on OpenCV 3.x but
        # (contours, hierarchy) on 2.x/4.x; the contours are always the
        # second-to-last element, so index from the end.
        cnts = cv2.findContours(
            diff.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        coordinate = []
        bigC = None
        bigMulti = 0
        for c in cnts:
            if cv2.contourArea(c) < 1500:
                continue
            (x, y, w, h) = cv2.boundingRect(c)
            if w * h > bigMulti:
                bigMulti = w * h
                bigC = ((x, y), (x + w, y + h))
        if bigC:
            cv2.rectangle(frame, bigC[0], bigC[1], (255, 0, 0), 2, 1)
        # The single entry is None when nothing large enough moved.
        coordinate.append(bigC)
        message = {"coord": coordinate}
        message['msg'] = None
        return MessageItem(frame, message)


class Tracker(object):
    """Tracks one user-selected target with an OpenCV tracker."""

    # Maps the public tracker names to OpenCV factory function names.
    _FACTORY_NAMES = {
        'BOOSTING': 'TrackerBoosting_create',
        'MIL': 'TrackerMIL_create',
        'KCF': 'TrackerKCF_create',
        'TLD': 'TrackerTLD_create',
        'MEDIANFLOW': 'TrackerMedianFlow_create',
        'GOTURN': 'TrackerGOTURN_create',
    }

    def __init__(self, tracker_type="BOOSTING", draw_coord=True):
        """Select and build the tracker.

        :param tracker_type: one of ``self.tracker_types``.
        :param draw_coord: when True, ``track`` draws the tracked box on
            the frame.
        """
        self.tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN']
        self.tracker_type = tracker_type
        self.isWorking = False
        self.draw_coord = draw_coord
        # Stays None when no tracker could be built, so initWorking can
        # report it instead of failing with AttributeError.
        self.tracker = self._create_tracker(tracker_type)

    @classmethod
    def _create_tracker(cls, tracker_type):
        """Build the tracker by feature detection instead of version parsing.

        Looks for the per-class factory in ``cv2`` and then ``cv2.legacy``,
        and finally falls back to the old ``cv2.Tracker_create(name)``.
        Returns None when nothing matches.
        """
        factory_name = cls._FACTORY_NAMES.get(tracker_type)
        if factory_name is not None:
            for namespace in (cv2, getattr(cv2, "legacy", None)):
                factory = getattr(namespace, factory_name, None)
                if factory is not None:
                    return factory()
        if hasattr(cv2, "Tracker_create"):
            return cv2.Tracker_create(tracker_type)
        return None

    def initWorking(self, frame, box):
        """Initialise tracking.

        :param frame: first frame to track on.
        :param box: region to track, as passed to the OpenCV tracker.
        :raises Exception: when no tracker was built or initialisation fails.
        """
        if self.tracker is None:
            raise Exception("追踪器未初始化")
        status = self.tracker.init(frame, box)
        # Older trackers return a bool; the newer API returns None on
        # success, so only an explicit False counts as failure.
        if status is False:
            raise Exception("追踪器工作初始化失败")
        self.coord = box
        self.isWorking = True

    def track(self, frame):
        """Update the tracker with ``frame``.

        :return: a ``MessageItem`` wrapping ``frame``; its message is None
            when the tracker is not working or lost the target, otherwise
            ``{"coord": [((x1, y1), (x2, y2))], "msg": "is tracking"}``.
        """
        message = None
        if self.isWorking:
            status, self.coord = self.tracker.update(frame)
            if status:
                p1 = (int(self.coord[0]), int(self.coord[1]))
                p2 = (int(self.coord[0] + self.coord[2]),
                      int(self.coord[1] + self.coord[3]))
                message = {"coord": [(p1, p2)]}
                if self.draw_coord:
                    cv2.rectangle(frame, p1, p2, (255, 0, 0), 2, 1)
                # Always present, so consumers can read 'msg' whether or
                # not drawing is enabled.
                message['msg'] = "is tracking"
        return MessageItem(frame, message)


class ObjectTracker(object):
    """Marks cascade-classifier detections on a frame."""

    def __init__(self, dataSet):
        """:param dataSet: argument passed to ``cv2.CascadeClassifier``."""
        self.cascade = cv2.CascadeClassifier(dataSet)

    def track(self, frame):
        """Draw a rectangle around every detection and return ``frame``."""
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = self.cascade.detectMultiScale(gray, 1.03, 5)
        for (x, y, w, h) in faces:
            cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
        return frame


if __name__ == '__main__':
    tracker = Tracker(tracker_type="KCF")
    video = cv2.VideoCapture(0)
    try:
        ok, frame = video.read()
        if not ok:
            raise RuntimeError("Unable to read a frame from the camera")
        bbox = cv2.selectROI(frame, False)
        tracker.initWorking(frame, bbox)
        while True:
            ok, frame = video.read()
            if not ok:
                # Stop instead of spinning forever once the stream ends.
                break
            item = tracker.track(frame)
            cv2.imshow("track", item.getFrame())
            k = cv2.waitKey(1) & 0xff
            if k == 27:  # Esc quits
                break
    finally:
        video.release()
        cv2.destroyAllWindows()
#encoding=utf-8
import json
from utils import IOUtil

# Message wrapper class.


class MessageItem(object):
    """Bundles an image frame with an accompanying message."""

    def __init__(self, frame, message):
        """Store the frame and the message as given."""
        self._message = message
        self._frame = frame

    def getFrame(self):
        """Return the stored image frame."""
        return self._frame

    def getMessage(self):
        """Return the stored message."""
        return self._message

    def getBase64Frame(self):
        """Return the frame encoded by IOUtil and then base64-encoded.

        The last axis of the frame is reversed first (BGR to RGB, per the
        original author's note).
        """
        reversed_channels = self._frame[..., ::-1]
        encoded_image = IOUtil.array_to_bytes(reversed_channels)
        return IOUtil.bytes_to_base64(encoded_image)

    def getBase64FrameByte(self):
        """Return the base64-encoded frame passed through ``bytes()``."""
        base64_frame = self.getBase64Frame()
        return bytes(base64_frame)

    def getJson(self):
        """Return a JSON string with the base64 frame and the message."""
        payload = {}
        payload["frame"] = self.getBase64Frame().decode()
        payload["message"] = self.getMessage()
        return json.dumps(payload)

    def getBinaryFrame(self):
        """Return the channel-reversed frame encoded by IOUtil."""
        reversed_channels = self._frame[..., ::-1]
        return IOUtil.array_to_bytes(reversed_channels)

运行之后,在第一帧图像上框选要追踪的区域即可开始追踪;这里测试的是使用 KCF 算法的追踪器

更新:忘记放utils,给大家造成的困扰深表歉意

#encoding=utf-8
import time
import numpy
import base64
import os
import logging
import struct
import sys
from settings import *
from PIL import Image
from io import BytesIO


# Utility classes
class IOUtil(object):
    """Helpers for image encoding, packet framing and small geometry tasks."""

    @staticmethod
    def array_to_bytes(pic, formatter="jpeg", quality=70):
        """Encode a numpy image array into an in-memory image file.

        :param pic: numpy array accepted by ``PIL.Image.fromarray``.
        :param formatter: image format name passed to ``Image.save``.
        :param quality: quality setting passed to ``Image.save``.
        :return: the encoded image as bytes.
        """
        # The context manager closes the stream even when save() raises.
        with BytesIO() as stream:
            picture = Image.fromarray(pic)
            picture.save(stream, format=formatter, quality=quality)
            return stream.getvalue()

    @staticmethod
    def bytes_to_base64(byte):
        """Return ``byte`` base64-encoded (as bytes)."""
        return base64.b64encode(byte)

    @staticmethod
    def transport_rgb(frame):
        """Reverse the last axis of ``frame`` (BGR to RGB or RGB to BGR)."""
        return frame[..., ::-1]

    @staticmethod
    def byte_to_package(bytes, cmd, var=1):
        """Prefix a binary payload with a fixed-size header.

        The header is three network-order unsigned 32-bit integers:
        version, payload length and command.

        :param bytes: binary payload (the name shadows the builtin but is
            kept for caller compatibility).
        :param cmd: command code.
        :param var: protocol version written into the header.
        :return: header followed by the payload.
        """
        # The original referenced undefined names ``ver``/``byte`` and an
        # unimported ``struct``; the parameters are what was intended.
        head_pack = struct.pack("!3I", var, len(bytes), cmd)
        return head_pack + bytes

    @staticmethod
    def mkdir(filePath):
        """Create the directory, including parents, if it does not exist."""
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(filePath, exist_ok=True)

    @staticmethod
    def countCenter(box):
        """Return the centre of a box given as two points ``((x1, y1), (x2, y2))``.

        The half-extent is added to the first point's coordinates.
        """
        return (int(abs(box[0][0] - box[1][0]) * 0.5) + box[0][0],
                int(abs(box[0][1] - box[1][1]) * 0.5) + box[0][1])

    @staticmethod
    def countBox(center):
        """Convert two points into ``(x, y, dx, dy)`` relative to the first."""
        return (center[0][0], center[0][1],
                center[1][0] - center[0][0], center[1][1] - center[0][1])

    @staticmethod
    def getImageFileName():
        """Return a ``.png`` file name built from the current local time."""
        return time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + '.png'


# Build the logger: the same formatter for a log file and for stdout.
logger = logging.getLogger(LOG_NAME)
formatter = logging.Formatter(LOG_FORMATTER)
IOUtil.mkdir(LOG_DIR)
file_handler = logging.FileHandler(LOG_DIR + LOG_FILE, encoding='utf-8')
file_handler.setFormatter(formatter)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)