1. 磐创AI首页
  2. Medium

使用TensorFlow对象检测API训练自定义对象检测模型


出发地:https://makeoptim.com/en/deep-learning/yiai-object-detectionhttps://makeoptim.com/en/deep-learning/yiai-object-detection

前言

本文将介绍对象检测的概念,并通过案例说明如何使用TensorFlow对象检测API训练一个自定义的对象检测器,包括数据集的采集和处理、TensorFlow对象检测API的安装、模型训练。TensorFlow Object Detection API

案例效果如下图所示:

目标检测


如上图所示,图像分类解决的问题是图片中的对象是什么,而对象检测可以识别图片中的对象和对象的位置(坐标)。

位置

目标检测的位置信息一般有两种格式:

  • 极坐标(xmin,ymin,xmax,ymax):xmin,ymin:x,y坐标的最小值;xmin,ymin:x,y坐标的最大值
  • CENTER POINT:(x_center,y_center,w,h):x_center,y_center:目标检测框的中心点坐标;w,h:目标检测框的宽度和高度

里程碑

传统方法(区域建议+人工特征提取+分类器)

HOG+支持向量机、DPM

地区建议书+CNN(两阶段)

R-CNN、SPP-NET、快速R-CNN、更快R-CNN

端到端(一阶段)

YOLO、固态硬盘

TensorFlow对象检测API

TensorFlow对象检测API是一个构建在TensorFlow之上的开源框架,它使得构建、训练和部署对象检测模型变得很容易。此外,TensorFlow对象检测API还提供了Model Zoo,以方便我们选择和切换预先训练的模型。TensorFlow Object Detection API TensorFlow Object Detection API Model Zoo

安装依赖项

  • 孔达
  • 协议

使用以下命令检查安装是否成功。

$ conda --version
conda 4.9.2
$ protoc --version
libprotoc 3.17.1

安装API

TensorFlow对象检测API提供的官方安装步骤比较繁琐。作者编写了一个脚本,只需一步即可直接安装。TensorFlow Object Detection API

执行git克隆https://github.com/CatchZeng/object-detection-api.git下载repo,然后转到repo所在的目录(以下称为ODA repo),如果您看到以下输出,则执行以下命令,表明安装成功。

$  conda create -n  od python=3.8.5 && conda activate od && make install
$ pip install --upgrade tf-models-official==2.4.0
$ pip install --upgrade tensorflow==2.4.1

创建工作区

$ conda activate od
$ conda env list
# conda environments:
#
od * /Users/catchzeng/.conda/envs/od
tensorflow /Users/catchzeng/.conda/envs/tensorflow
base /Users/catchzeng/miniconda3

转到ODA repo目录并执行以下命令以创建工作区目录结构。

$ make workspace-box SAVE_DIR=workspace NAME=test

数据集

图像

我喜欢喝茶。今天我将以杯子、茶壶、加湿器为例。

将采集到的图片放入项目目录中图片的三个子目录中。

注解

收集图片后,您需要对训练和评估集中的图像进行注释。

我们选择LabelImg作为注释工具。LabelImg

按照安装说明安装LabelImg,然后执行labelImg选择要注释的Train和Val文件夹。installation

批注完成后,将生成图片对应的XML批注文件,如下图所示:

workspace/test/images
├── test
│ ├── 15.jpg
│ └── 16.jpg
├── train
│ ├── 1.jpg
│ ├── 1.xml
│ ├── 10.jpg
│ ├── 10.xml
│ ├── 2.jpg
│ ├── 2.xml
│ ├── 3.jpg
│ ├── 3.xml
│ ├── 4.jpg
│ ├── 4.xml
│ ├── 5.jpg
│ ├── 5.xml
│ ├── 6.jpg
│ ├── 6.xml
│ ├── 7.jpg
│ ├── 7.xml
│ ├── 8.jpg
│ ├── 8.xml
│ ├── 9.jpg
│ └── 9.xml
└── val
├── 11.jpg
├── 11.xml
├── 12.jpg
├── 12.xml
├── 13.jpg
├── 13.xml
├── 14.jpg
└── 14.xml

LabelMap

在文件夹workspace/test/notation下创建label_map.pbtxt,内容是模型需要识别的对象。

item {
id: 1
name: 'cup'
}

创建TFRecord

TensorFlow Object Detection接口仅支持TFRecord格式,因此需要对数据集进行转换。TensorFlow Object Detection API TFRecord

转到工作区目录(cd workspace/test),然后执行make gen-tford,它将在Annotation文件夹中生成TFRecord格式的数据集。TFRecord

$ make gen-tfrecord
python ../../scripts/preprocessing/generate_tfrecord.py \
-x images/train \
-l annotations/label_map.pbtxt \
-o annotations/train.record
Successfully created the TFRecord file: annotations/train.record
python ../../scripts/preprocessing/generate_tfrecord.py \
-x images/val \
-l annotations/label_map.pbtxt \
-o annotations/val.record
Successfully created the TFRecord file: annotations/val.record

模范训练

下载预先训练好的模型

从Model Zoo中选择合适的模型,下载并解压缩,然后将其放入工作区/测试/预先训练的模型中。Model Zoo

如果选择SSD MobileNet V2 FPNLite 320×320,可以执行以下命令自动下载和解压缩

$ make dl-model

目录结构如下:

└── test
└── pre-trained-models
└── ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8
├── checkpoint
├── pipeline.config
└── saved_model

配置培训管道

在Models目录中创建相应的模型文件夹,例如:ssd_mobilenet_v2_fpnlite_320x320,复制pre-trained-models/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8/pipeline.config.

└── test
├── models
│ └── ssd_mobilenet_v2_fpnlite_320x320
│ └── pipeline.config
└── pre-trained-models

其中,Pipeline.config需要根据项目进行修改,具体如下

model {
ssd {
num_classes: 3 # Modify to the number of objects that need to be identified.
......
}
train_config {
batch_size: 8 # Here you need to adjust the size according to your own computer performance
......
optimizer {
momentum_optimizer {
learning_rate {
cosine_decay_learning_rate {
learning_rate_base: 0.07999999821186066
total_steps: 10000 # Modify to the total number of steps you want to train
warmup_learning_rate: 0.026666000485420227
warmup_steps: 1000
}
}
momentum_optimizer_value: 0.8999999761581421
}
use_moving_average: false
}
fine_tune_checkpoint: "pre-trained-models/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8/checkpoint/ckpt-0" # Modify the path to the pre-trained model
num_steps: 10000 # Modify to the total number of steps you want to train
startup_delay_steps: 0.0
replicas_to_aggregate: 8
max_number_of_boxes: 100
unpad_groundtruth_tensors: false
fine_tune_checkpoint_type: "detection" # Here needs to be modified to detection, because we are doing object detection
fine_tune_checkpoint_version: V2
}
train_input_reader {
label_map_path: "annotations/label_map.pbtxt" # Modify to the annotations path
tf_record_input_reader {
input_path: "annotations/train.record" # Modify the path to the training set
}
}
eval_config {
metrics_set: "coco_detection_metrics"
use_moving_averages: false
}
eval_input_reader {
label_map_path: "annotations/label_map.pbtxt" # Modify to the annotations path
shuffle: false
num_epochs: 1
tf_record_input_reader {
input_path: "annotations/val.record" # Modify the path to the evaluation set
}
}

培训模式

$ make train

模型导出和转换

  • 保存的模型

$进行导出

  • TFLite模型

$make export-lite

  • 转换TFLite模型

$make Convert-Lite

  • 量化TFLite模型

$make Convert-Quant-Lite

测试

在执行make export以导出模型之后,将测试图像放在image/test文件夹中,然后执行python test_images.py将带注释的图像输出到image/test_Annotated。

摘要

本文通过案例介绍了物体检测的全过程,希望能帮助您快速掌握培养自定义物体检测器的能力。

案例的代码和数据集已放置在https://github.com/CatchZeng/object-detection-api.中https://github.com/CatchZeng/object-detection-api

下面的文章将向您介绍对象检测的原理、流行的对象检测网络和图像分割。这篇文章就到这里,下次见。

参考文献

  • https://github.com/tensorflow/models/tree/master/research/object_detection
  • https://arxiv.org/pdf/1905.05055.pdf
  • https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md

原创文章,作者:fendouai,如若转载,请注明出处:https://panchuang.net/2021/09/20/%e4%bd%bf%e7%94%a8tensorflow%e5%af%b9%e8%b1%a1%e6%a3%80%e6%b5%8bapi%e8%ae%ad%e7%bb%83%e8%87%aa%e5%ae%9a%e4%b9%89%e5%af%b9%e8%b1%a1%e6%a3%80%e6%b5%8b%e6%a8%a1%e5%9e%8b/

联系我们

400-800-8888

在线咨询:点击这里给我发消息

邮件:admin@example.com

工作时间:周一至周五,9:30-18:30,节假日休息