本文介绍2024年新手掌握PaddleOCR的使用,包括其超轻量等特性。涵盖快速使用步骤,如解压数据集、安装环境、准备模型和测试图像,及单张和多张图像测试。还讲解训练数据集、文字检测与识别的训练和测试,模型转换,以及知识蒸馏的配置、训练等内容,最后提及多种部署选项。
☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

本教程旨在帮助使用者快速了解PaddleOCR,并掌握PaddleOCR的使用方式,包括:
本节介绍如何使用PaddleOCR的轻量级模型完成文本检测、识别的任务。
轻量级OCR模型效果:
!ls /home/aistudio/data
data294006
!mkdir /home/aistudio/external-libraries !pip install beautifulsoup4 -t /home/aistudio/external-libraries
%cd ~
/home/aistudio
!git clone https://github.com/PaddlePaddle/PaddleOCR.git
Cloning into 'PaddleOCR'... remote: Enumerating objects: 51224, done. remote: Counting objects: 100% (1542/1542), done. remote: Compressing objects: 100% (848/848), done. remote: Total 51224 (delta 808), reused 1315 (delta 680), pack-reused 49682 (from 1) Receiving objects: 100% (51224/51224), 385.22 MiB | 7.47 MiB/s, done. Resolving deltas: 100% (35953/35953), done. Updating files: 100% (2390/2390), done.
%cd ~/PaddleOCR
/home/aistudio/PaddleOCR
# 安装依赖库!pip install -r requirements.txt -i https://mirror.baidu.com/pypi/simple
!mkdir inference && cd inference# 下载8.6M中文模型的检测模型并解压# ! cd inference && wget https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar && tar xf ch_det_mv3_db_infer.tar# # 下载8.6M中文模型的识别模型并解压# ! cd inference && wget https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar && tar xf ch_rec_mv3_crnn_infer.tar!cd inference && wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_det_infer.tar !cd inference && wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
准备了一些测试图像,上传图像测试我们的OCR模型。使用代码可视化测试图像
import matplotlib.pyplot as pltfrom PIL import Image# 图像文件绝对路径img_path = "/home/aistudio/data/data294006/ppocr_img/imgs/11.jpg"# 打开图像文件img = Image.open(img_path)# 显示图像plt.figure("test_img", figsize=(10,10))
plt.imshow(img)
plt.show()<Figure size 1000x1000 with 1 Axes>
下面开始调用tools/infer/predict_system.py 完成图像文本识别,共需要传入三个参数:
image_dir: 指定要测试的图像
det_model_dir: 指定轻量检测模型的inference model
rec_model_dir: 指定轻量识别模型的inference model
!pwd !ls
/home/aistudio/PaddleOCR LICENSE applications docs paddleocr.py setup.py MANIFEST.in benchmark inference ppocr test_tipc README.md configs inference_results ppstructure tests README_en.md deploy mkdocs.yml pyproject.toml tools __init__.py doc overrides requirements.txt train.sh
# 快速运行!python3 tools/infer/predict_system.py --image_dir="/home/aistudio/data/data294006/ppocr_img/imgs/11.jpg" --det_model_dir="./inference/ch_PP-OCRv3_det_infer/" --rec_model_dir="./inference/ch_PP-OCRv3_rec_infer/"
[2024/09/11 22:51:22] ppocr WARNING: The first GPU is used for inference by default, GPU ID: 0 E0911 22:51:22.421043 11060 analysis_config.cc:126] Please use PaddlePaddle with GPU version. [2024/09/11 22:51:22] ppocr WARNING: The first GPU is used for inference by default, GPU ID: 0 E0911 22:51:22.950991 11060 analysis_config.cc:126] Please use PaddlePaddle with GPU version. [2024/09/11 22:51:23] ppocr INFO: In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320 [2024/09/11 22:51:23] ppocr DEBUG: dt_boxes num : 16, elapsed : 0.3190124034881592 [2024/09/11 22:51:24] ppocr DEBUG: rec_res num : 16, elapsed : 1.1868066787719727 [2024/09/11 22:51:24] ppocr DEBUG: 0 Predict time of /home/aistudio/data/data294006/ppocr_img/imgs/11.jpg: 1.511s [2024/09/11 22:51:24] ppocr DEBUG: 纯臻营养护发素, 0.966 [2024/09/11 22:51:24] ppocr DEBUG: 产品信息/参数, 0.911 [2024/09/11 22:51:24] ppocr DEBUG: (45元/每公斤,100公斤起订), 0.884 [2024/09/11 22:51:24] ppocr DEBUG: 每瓶22元,1000瓶起订), 0.921 [2024/09/11 22:51:24] ppocr DEBUG: 【品牌】:代加工方式/OEMODM, 0.966 [2024/09/11 22:51:24] ppocr DEBUG: 【品名】:纯臻营养护发素, 0.883 [2024/09/11 22:51:24] ppocr DEBUG: 【产品编号】:YM-X-3011, 0.872 [2024/09/11 22:51:24] ppocr DEBUG: ODMOEM, 0.954 [2024/09/11 22:51:24] ppocr DEBUG: 【净含量】:220ml, 0.935 [2024/09/11 22:51:24] ppocr DEBUG: 适用人群):适合所有肤质, 0.887 [2024/09/11 22:51:24] ppocr DEBUG: 【主要成分】:鲸蜡硬脂醇、燕麦β-葡聚, 0.925 [2024/09/11 22:51:24] ppocr DEBUG: 糖、椰油xian胺bing基甜菜碱、泛醒, 0.937 [2024/09/11 22:51:24] ppocr DEBUG: (成品包材), 0.893 [2024/09/11 22:51:24] ppocr DEBUG: 【主要功能】:可紧致头发磷层,从而达到, 0.868 [2024/09/11 22:51:24] ppocr DEBUG: 即时持久改善头发光泽的效果,给于燥的头, 0.880 [2024/09/11 22:51:24] ppocr DEBUG: 发足够的滋养, 0.909 [2024/09/11 22:51:25] ppocr DEBUG: The visualized image saved in ./inference_results/11.jpg [2024/09/11 22:51:25] ppocr INFO: The predict total time is 1.6732959747314453
输出结果中有两列数据,第一列表示PaddleOCR识别出的文字,第二列表示识别出当前文字的置信度。置信度的数据范围是[0-1],置信度越接近1表示文本识别对的“信心”越大。
如下输出结果中:
[2024/09/11 22:51:24] ppocr DEBUG: 【主要成分】:鲸蜡硬脂醇、燕麦β-葡聚, 0.925
[2024/09/11 22:51:24] ppocr DEBUG: 糖、椰油xian胺bing基甜菜碱、泛醒, 0.937
[2024/09/11 22:51:24] ppocr DEBUG: (成品包材), 0.893
[2024/09/11 22:51:24] ppocr DEBUG: 【主要功能】:可紧致头发磷层,从而达到, 0.868
[2024/09/11 22:51:24] ppocr DEBUG: 即时持久改善头发光泽的效果,给于燥的头, 0.880
同时,识别结果会可视化在图像中并保存在./inference_results文件夹下,可以通过左边的目录结构选择要打开的文件, 也可以通过如下代码将可视化后的图像显示出来,观察OCR文本识别的效果。
image_dir支持传入单张图像和图像所在的文件目录,当image_dir指定的是图像目录时,运行上述指令会预测当前文件夹下的所有图像中的文字,并将预测的可视化结果保存在inference_results文件夹下。
# 快速运行!python3 tools/infer/predict_system.py --image_dir="/home/aistudio/data/data294006/ppocr_img/imgs/" --det_model_dir="./inference/ch_PP-OCRv3_det_infer/" --rec_model_dir="./inference/ch_PP-OCRv3_rec_infer/"
## 显示轻量级模型识别结果## 可视化的文本识别效果img_path= "/home/aistudio/data/data294006/ppocr_img/imgs/11.jpg"img = Image.open(img_path)
plt.figure("results_img", figsize=(5,5))
plt.imshow(img)
plt.show()<Figure size 500x500 with 1 Axes>
## 显示轻量级模型识别结果img_path= "./inference_results/11.jpg"img = Image.open(img_path)
plt.figure("results_img", figsize=(10,30))
plt.imshow(img)
plt.show()<Figure size 1000x3000 with 1 Axes>
%cd ~/PaddleOCR/# 根据backbone的不同选择下载对应的预训练模型# 下载MobileNetV3的预训练模型!wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams# 或,下载ResNet18_vd的预训练模型!wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams# 或,下载ResNet50_vd的预训练模型!wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
全部打标完成之后,点击文件选择导出标记结果,再点击文件选择导出识别结果,完成后再文件夹多出四个文件fileState,Label,rec_gt, crop_img。其中crop_img中的图片用来训练文字识别模型,
fileState记录图片的打标完成与否,Label为训练文字检测模型的标签,rec_gt为训练文字识别模型的标签。
6:2:2表示将数据集按照 6:2:2 的比例分割训练集、验证集和测试集,即训练集占总数据集的60%,验证集占20%,测试集占20%。
!rm -rf /home/aistudio/PaddleOCR/train_data !ln -s /home/aistudio/data/data264561/train_data /home/aistudio/PaddleOCR/ %cd /home/aistudio/PaddleOCR/PPOCRLabel !python3 gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data/drivingData/corrected_images
%cd /home/aistudio/PaddleOCR
!pip install visualdl
!python3 tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml
!visualdl --logdir "output/ch_db_res18/vdl"
模型训练完之后会在文件夹下保存训练好的模型,具体保存的文件夹看配置文件,这就是模型保存的路径:save_model_dir: ./output/ch_db_res18/
使用best_accuracy.pdparams进行模型测试。
在终端中输入以下指令进行测试。其中Global.pretrained_model是我们训练好并且需要测试的模型,Global.infer_img为所要检测的图片路径
!python tools/infer_det.py \
-c configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml \
-o Global.pretrained_model=/home/aistudio/PaddleOCR/output/ch_db_res18/latest.pdparams \
Global.infer_img=/home/aistudio/data/data264561/train_data/det/test/5320_corrected.jpg!python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
!python tools/infer_rec.py \
-c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml \
-o Global.pretrained_model=/home/aistudio/PaddleOCR/output/rec_ppocr_v3/latest.pdparams \
Global.infer_img=/home/aistudio/data/data264561/train_data/det/test/5325_corrected.jpg其中Global.pretrained_model是我们训练好并且需要推理的模型,Global.save_inference_dir为所要保存推理模型的位置。推理模型是可以直接被调用进行识别和检测。分别把训练好的文字检测模型和文字识别模型推理。
!python tools/export_model.py \
-c "./configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml" \
-o Global.pretrained_model=/home/aistudio/PaddleOCR/output/rec_ppocr_v3/latest.pdparams \
Global.save_inference_dir="./inference_model/rec/"!python tools/export_model.py \
-c "./configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml" \
-o Global.pretrained_model=/home/aistudio/PaddleOCR/output/ch_db_res18/latest.pdparams \
Global.save_inference_dir="./inference_model/det/"其中det和rec是推理模型,可以用predict_system.py进行验证
!python tools/infer/predict_system.py \
--image_dir=/home/aistudio/data/data264561/train_data/det/test/5320_corrected.jpg \
--det_model_dir="./inference_model/det/" \
--rec_model_dir="./inference_model/rec"识别结果会可视化在图像中并保存在./inference_results文件夹下,可以通过左边的目录结构选择要打开的文件, 也可以通过如下代码将可视化后的图像显示出来,观察OCR文本识别的效果。
在数据量足够大的情况下,通过合理构建网络模型的方式增加其参数量,可以显着改善模型性能,但这又带来了模型复杂度请求提升的问题。大模型在实际场景中使用的成本首先。
深度神经网络一般有局部的参数再现,目前有几种主要的方法对模型进行压缩,减小其参数量。如教师量化、知识调整等,其中知识调整是指使用模型(教师模型) )去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升。
另外,在知识补充任务中,也衍生出了互训学习的模型训练方法,论文深度互助中指出,使用两个模型在训练的过程中互相监督,可以达到比单个模型更好的效果。
在知识增量训练的过程中,数据结构、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后期处理、指标计算等模块的配置文件需要进行调整。
下面以识别与检测的知识配置配置文件为例,对知识配置的训练与配置进行解析。
配置文件在ch_PP-OCRv3_rec_distillation.yml。
知识补充任务中,模型结构配置如下所示
Architecture:
model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的模型类别
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
algorithm: Distillation # 算法名称
Models: # 模型,包含子网络的配置信息
Teacher: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
pretrained: # 该子网络是否需要加载预训练模型
freeze_params: false # 是否需要固定参数
return_all_feats: true # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type: *model_type # 模型类别
algorithm: SVTR # 子网络的算法名称,该子网络其余参数均为构造参数,与普通的模型训练配置一致
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
last_conv_stride: [1, 2] last_pool_type: avg
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 64
depth: 2
hidden_dims: 120
use_guide: True
Head:
fc_decay: 0.00001
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: SVTR
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
last_conv_stride: [1, 2] last_pool_type: avg
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 64
depth: 2
hidden_dims: 120
use_guide: True
Head:
fc_decay: 0.00001
- SARHead:
enc_dim: 512
max_text_length: *max_text_length当然,这里如果希望添加更多的子网络进行训练,也可以按照Student与Teacher的添加方式,在配置文件中添加相应的字段。比如如果希望有3个模型互相监督,共同训练可以,那么Architecture写为如上格式。
最终该模型训练时,包含3个子网络:Teacher, Student, Student2。
补充模型DistillationModel类的具体实现代码可以参考distillation_model.py。
最终模型forward输出为一个字典,key为所有的子网络名称,例如这里为Student与Teacher,value为对应子网络的输出,可以为Tensor(只返回该网络的最后一层)和dict(也返回了中间的特征信息) 。
在每个识别任务中,为了添加更多的损失函数,保证后续方法的可扩展性,将子网络的输出保存为dict,其中包含子模块输出。以该识别模型为例,每个子网络的输出结果结果dict,key包含backbone_out, neck_out, head_out,value为模块配置的张量,最终对于上述对应文件,DistillationModel的输出格式如下。
{ "Teacher": { "backbone_out": tensor, "neck_out": tensor, "head_out": tensor,
}, "Student": { "backbone_out": tensor, "neck_out": tensor, "head_out": tensor,
}
}知识补充任务中,损失函数配置如下所示。
Loss: name: CombinedLoss loss_config_list:
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss weight: 1.0 # 权重 act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None use_log: true # 对输入计算log,如果函数已经 model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
- ["Student", "Teacher"] key: head_out # 取子网络输出dict中,该key对应的tensor multi_head: True # 是否为多头结构 dis_head: ctc # 指定用于计算损失函数的head name: dml_ctc # 蒸馏loss的前缀名称,避免不同loss之间的命名冲突
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss weight: 0.5 # 权重 act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None use_log: true # 对输入计算log,如果函数已经 model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
- ["Student", "Teacher"] key: head_out # 取子网络输出dict中,该key对应的tensor multi_head: True # 是否为多头结构 dis_head: sar # 指定用于计算损失函数的head name: dml_sar # 蒸馏loss的前缀名称,避免不同loss之间的命名冲突
- DistillationDistanceLoss: # 蒸馏的距离损失函数 weight: 1.0 # 权重 mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1 model_name_pairs: # 用于计算distance loss的子网络名称对
- ["Student", "Teacher"] key: backbone_out # 取子网络输出dict中,该key对应的tensor
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段 model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationSARLoss: # 基于蒸馏的SAR损失函数,继承自标准的SARLoss weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段 model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss key: head_out # 取子网络输出dict中,该key对应的tensor multi_head: True # 是否为多头结构,为true时,取出其中的SAR分支计算损失函数结果损失函数中,所有的后续损失函数均继承自标准的损失函数类,主要功能为:对后续模型的输出进行解析,找到计算损失的中间节点(张量),再使用标准的损失函数类去计算。
以上述配置为例,最终操作的损失函数包含下面5个部分。
关于CombinedLoss更具体的实现可以参考:combined_loss.py。关于DistillationCTCLoss等损失函数更具体的实现可以参考distillation_loss.py。
知识补充任务中,后续处理配置如下所示。
PostProcess: name: DistillationCTCLabelDecode # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类 model_name: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,进行解码 key: head_out # 取子网络输出dict中,该key对应的tensor multi_head: True # 多头结构时,会取出其中的CTC分支进行计算
以上述配置为例,最终会同时计算Student和Teahcer2个子网的CTC解码输出,返回一个dict,key为用于处理的子网名称,value为用于处理的子网列表。
关于DistillationCTCLabelDecode更具体的实现可以参考:rec_postprocess.py
知识补充任务中,指标计算配置如下所示。
Metric: name: DistillationMetric # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类 base_metric_name: RecMetric # 指标计算的基类,对于模型的输出,会基于该类,计算指标 main_indicator: acc # 指标的名称 key: "Student" # 选取该子网络的 main_indicator 作为作为保存保存best model的判断标准 ignore_space: False # 评估时是否忽略空格的影响
以上述配置为例,最终会使用Student子网络的acc指标作为保存最佳模型的判断指标,同时,日志中基因打印出所有子网络的acc指标。
关于DistillationMetric更具体的实现可以参考:distillation_metric.py。 添加模型配置 对获得的识别进行配制有2种方式。
(1)基于知识配置的账户路径:这种情况比较简单,下载预训练模型,在ch_PP-OCRv3_rec_distillation.yml中配置好预训练模型路径以及自己的数据路径,即可进行模型账户训练。
(2)模型时不使用知识增量:这种情况,需要首先将预训练模型中的学生模型参数提取出来,
!python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
然后使用python,对其中的学生模型参数进行提取
转化完成之后,使用ch_PP-OCRv3_rec.yml,修改预训练模型的路径(为导出的student.pdparams模型路径)以及自己的数据路径,即可进行模型模型。
import paddle# 加载预训练模型all_params = paddle.load("output/rec_ppocr_v3_distillation/best_accuracy.pdparams")# 查看权重参数的keysprint(all_params.keys())# 学生模型的权重提取s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}# 查看学生模型权重参数的keysprint(s_params.keys())# 保存paddle.save(s_params, "output/rec_ppocr_v3_distillation/student.pdparams")检测模型的补充配置文件在PaddleOCR/configs/det/ch_PP-OCRv3/目录下,包含两个补充配置文件:
检测模型的补充配置文件在PaddleOCR/configs/det/ch_PP-OCRv3/目录下,包含两个补充配置文件:
ch_PP-OCRv3_det_cml.yml,采用cml移植,采用一个大模型移植两个小模型,且两个小模型互相学习的方法 ch_PP-OCRv3_det_dml.yml,采用DML的补充,两个Student模型相互补充的方法
知识补充任务中,模型结构配置如下所示:
Architecture:
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
algorithm: Distillation # 算法名称
Models: # 模型,包含子网络的配置信息
Student: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
freeze_params: false # 是否需要固定参数
return_all_feats: false # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type: det
algorithm: DB
Backbone:
name: ResNet
in_channels: 3
layers: 50
Neck:
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2] k: 50
Teacher: # 另外一个子网络,这里给的是DML蒸馏示例,
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
in_channels: 3
layers: 50
Neck:
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2] k: 50如果是采用DML,即两个小模型互相学习的方法,上述配置文件里的教师网络结构需要设置为学生模型一样的配置,具体参考配置文件ch_PP-OCRv3_det_dml.yml。
上面介绍ch_PP-OCRv3_det_cml.yml的配置文件参数
补充模型DistillationModel类的具体实现代码可以参考distillation_model.py。
最终模型forward输出为一个字典,key为所有的子网络名称,例如这里为Student与Teacher,value为对应子网络的输出,可以为Tensor(只返回该网络的最后一层)和dict(也返回了中间的特征信息) 。
在执行任务中,为了方便添加附加损失函数,每个网络的输出保存为dict,其中包含子模块输出。每个子网络的输出结果结果dict,key包含backbone_out,,,对应neck_out,为模块的张量head_out,value最终用于配置上述文件,DistillationModel输出格式如下。
{ "Teacher": { "backbone_out": tensor, "neck_out": tensor, "head_out": tensor,
}, "Student": { "backbone_out": tensor, "neck_out": tensor, "head_out": tensor,
}
}检测到 ch_PP-OCRv3_det_cml.yml 添加损失函数配置如下所示。
Loss: name: CombinedLoss loss_config_list:
- DistillationDilaDBLoss: weight: 1.0 model_name_pairs:
- ["Student", "Teacher"]
- ["Student2", "Teacher"] # 改动1,计算两个Student和Teacher的损失 key: maps balance_loss: true main_loss_type: DiceLoss alpha: 5 beta: 10 ohem_ratio: 3
- DistillationDMLLoss: # 改动2,增加计算两个Student之间的损失 model_name_pairs:
- ["Student", "Student2"] maps_name: "thrink_maps" weight: 1.0
# act: None key: maps
- DistillationDBLoss: weight: 1.0 model_name_list: ["Student", "Student2"] # 改动3,计算两个Student和GT之间的损失 balance_loss: true main_loss_type: DiceLoss alpha: 5 beta: 10 ohem_ratio: 3关于DistillationDilaDBLoss更具体的实现可以参考:distillation_loss.py。关于DistillationDBLoss等损失函数更具体的实现可以参考distillation_loss.py。
知识补充任务中,检测补充后续处理配置如下所示
PostProcess: name: DistillationDBPostProcess # DB检测蒸馏任务的CTC解码后处理,继承自标准的DBPostProcess类 model_name: ["Student", "Student2", "Teacher"] # 对于蒸馏模型的预测结果,提取多个子网络的输出,进行解码,不需要后处理的网络可以不在model_name中设置 thresh: 0.3 box_thresh: 0.6 max_candidates: 1000 unclip_ratio: 1.5
知识增量任务中,检测增量指标计算配置如下所示。
Metric: name: DistillationMetric base_metric_name: DetMetric main_indicator: hmean key: "Student"
由于需要包含多个网络,甚至多个学生网络,在计算指标的时候只需要计算一个学生网络的指标即可,key字段设置为Student则表示只计算Student网络的精度。
PP-OCRv3检测补充有两种方式:
在精度提升方面,cml的精度>dml的精度调整方法的精度。当数据量不足或者教师模型精度与学生精度相差不大的时候,这个结论可能会改变。
另外,由于PaddleOCR提供的补充预训练模型包含了多个模型的参数,如果您希望提取学生模型的参数,可以参考以下代码:
!python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
最终Student模型的参数将保存在ch_PP-OCRv3_det_distill_train/student.pdparams中,用于模型的微调。
!python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml
import paddle# 加载预训练模型all_params = paddle.load("output/ch_db_mv3/latest.pdparams")# 查看权重参数的keysprint(all_params.keys())# 学生模型的权重提取s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}# 查看学生模型权重参数的keysprint(s_params.keys())# 保存paddle.save(s_params, "output/ch_db_mv3/student.pdparams")
以上就是2024年最新版PaddleOCR新手指导(训练自己的数据集与知识蒸馏)的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号