为PaddleRS添加一个袖珍配置系统

P粉084495128
发布: 2025-08-01 14:24:32
原创
1018人浏览过
本文为PaddleRS设计轻量级配置系统,支持yaml文件与命令行选项配置,实现单继承。通过CfgNode抽象配置项,完成配置文件与命令行相互转换,以外部脚本封装PaddleRS。经变化检测等任务检验,可实现模型训练和推理,后续可优化配置编写与动态作用域。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

为paddlers添加一个袖珍配置系统 - php中文网

为PaddleRS添加一个袖珍配置系统

1 项目动机


PaddleRS提供了许多便捷的API,通过编写脚本、简单组合这些API,只需要数行代码就可以实现模型的训练和推理。然而,一方面,有些同学可能已经习惯了PaddleSeg、PaddleDetection等套件基于配置文件的“零代码”使用方式,对需要自己编写代码的API调用方式不太熟悉;另一方面,在科研或项目过程中,基于配置文件的版本控制有助于减轻实验负担。为此,本项目尝试为PaddleRS添加一个轻量级、非侵入式、可继承的配置系统,支持yaml文件与命令行选项两种配置方式。

2 设计方案


如上节所述,本项目旨在为PaddleRS添加一个轻量级、非侵入式、可继承的配置系统,且同时支持配置文件和命令行选项两种配置方式。为了实现上述功能,本项目做出如下设计方案:

  1. 精简功能,代码行数控制在500行以内,尽可能保证逻辑简单,不做不必要的抽象,不过度设计;
  2. 以外部脚本的形式对PaddleRS进行封装,从而做到不修改PaddleRS原有代码;
  3. 考虑到轻量性,牺牲多继承功能,仅实现配置文件的单继承;
  4. 为了让代码更加易用,在配置文件继承方面仿照PaddleSeg的风格,即,使用_base_选项指定要继承的文件;
  5. 实现配置文件和命令行选项之间的相互转换(最好能无失真)。

3 代码实现


3.1 对配置项的抽象

本着“避免过度设计”的原则,本项目以键值对方式组织配置项,配置项的键均为字符串,而值的类型一共有四种:

  • 标量:如1145.14、'homo'等;
  • 字典:其中包含数个配置项(键值对);
  • 列表:其中顺序包含数个值;
  • CfgNode:可以构造为特殊对象(如模型、数据集、优化器等),具有type、args、module三个属性。

CfgNode的存在是为了方便从配置文件中提取出一些“特殊”的项,用于后续构造“目标对象”。CfgNode的type指的是目标对象的类型名;args以列表或字典的形式包含构造目标对象时的输入参数;而module则是一个模块名,其中包含type指定的类。

CfgNode的完整定义如下:

class _CfgNodeMeta(yaml.YAMLObjectMetaclass):
    def __call__(cls, obj):
        if isinstance(obj, CfgNode):            return obj        return super(_CfgNodeMeta, cls).__call__(obj)class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta):
    yaml_tag = u'!Node'
    yaml_loader = yaml.SafeLoader    # By default use a lexical scope
    ctx = globals()    def __init__(self, dict_):
        super().__init__()
        self.type = dict_['type']
        self.args = dict_.get('args', [])
        self.module = self._get_module(dict_.get('module', ''))    @classmethod
    def set_context(cls, ctx):
        # TODO: Implement dynamic scope with inspect.stack()
        old_ctx = cls.ctx
        cls.ctx = ctx        return old_ctx    def build_object(self, mod=None):
        if mod is None:
            mod = self.module
        cls = getattr(mod, self.type)        if isinstance(self.args, list):
            args = build_objects(self.args)
            obj = cls(*args)        elif isinstance(self.args, dict):
            args = build_objects(self.args)
            obj = cls(**args)        else:            raise NotImplementedError        return obj    def _get_module(self, s):
        mod = None
        while s:
            idx = s.find('.')            if idx == -1:
                next_ = s
                s = ''
            else:
                next_ = s[:idx]
                s = s[idx+1:]            if mod is None:
                mod = self.ctx[next_]            else:
                mod = getattr(mod, next_)        return mod    @staticmethod
    def build_objects(cfg, mod=None):
        if isinstance(cfg, list):            return [CfgNode.build_objects(c, mod=mod) for c in cfg]        elif isinstance(cfg, CfgNode):            return cfg.build_object(mod=mod)        elif isinstance(cfg, dict):            return {k: CfgNode.build_objects(v, mod=mod) for k, v in cfg.items()}        else:            return cfg    def __repr__(self):
        return f"(type={self.type}, args={self.args}, module={self.module or ' '})"    @classmethod
    def from_yaml(cls, loader, node):
        map_ = loader.construct_mapping(node)        return cls(map_)    def items(self):
        yield from [('type', self.type), ('args', self.args), ('module', self.module)]    def to_dict(self):
        return dict(self.items())
登录后复制
       

首先可以注意到CfgNode的父类为yaml.YAMLObject,元类为_CfgNodeMeta。继承自yaml.YAMLObject,同时定义yaml_tag、yaml_loader以及from_yaml方法,可以让pyyaml自动从配置文件中解析出CfgNode对象。_CfgNodeMeta元类用于实现这样的功能:当输入参数为CfgNode对象时,不重新构造一个新的实例,而是直接返回输入的对象。

items()方法返回一个生成器,在部分场合下使CfgNode对象可以模拟Mapping类型的行为。to_dict()方法则将CfgNode对象转换为一个dict。

实例方法build_object()和静态方法build_objects()分别用于单个CfgNode对象和包含CfgNode对象的容器对目标对象的构造。由于self.args中可能也包含有CfgNode对象,因此需要进行递归构造。这块目前的性能不是很高,后续可以考虑优化。

类方法set_context()用于设置self.ctx,而后者则在_get_module()方法中被用于检索包含type类的模块。set_context()方法存在的原因是为了更好的解耦,归根到底是因为Python使用lexical scope,而我们并不希望在编写配置系统逻辑的部分添加任何的外部因素——例如import paddlers。当然,用set_context()方法手工设置上下文(甚至没有用到上下文管理器!)的方式十分暴力,后续可以考虑基于inspect.stack()实现一个具有dynamic scope的_get_module()方法。

3.2 实现配置文件和命令行选项的相互转换

考虑到配置文件是从命令行指定的,因此要想读取配置文件的内容,不得不先解析一次命令行选项。关于这一步,我用到一个小技巧,即使用Python标准库的argparse.ArgumentParser对象的parse_known_args()方法配合其构造函数的conflict_handler选项。具体逻辑如下:

    cfg_parser = argparse.ArgumentParser(add_help=False)
    cfg_parser.add_argument('--config', type=str, default='')
    cfg_parser.add_argument('--inherit_off', action='store_true')
    cfg_args = cfg_parser.parse_known_args()[0]
    cfg_path = cfg_args.config
    inherit_on = not cfg_args.inherit_off    # Main parser
    parser = argparse.ArgumentParser(conflict_handler='resolve', parents=[cfg_parser])    # Global settings
    parser.add_argument('cmd', choices=['train', 'eval'])    # Data
    parser.add_argument('--datasets', type=dict, default={})
    parser.add_argument('--transforms', type=dict, default={})    # 其它各种选项...
登录后复制
       

从配置文件读取的信息到命令行选项的转换如下:

def _cfg2args(cfg, parser, prefix=''):
    node_keys = set()    for k, v in cfg.items():
        opt = prefix+k        if isinstance(v, list):            if len(v) == 0:
                parser.add_argument('--'+opt, type=object, nargs='*', default=v)            else:                # Only apply to homogeneous lists
                if isinstance(v[0], CfgNode):
                    node_keys.add(opt)
                parser.add_argument('--'+opt, type=type(v[0]), nargs='*', default=v)        elif isinstance(v, dict):            # Recursively parse a dict
            _, new_node_keys = _cfg2args(v, parser, opt+'.')
            node_keys.update(new_node_keys)        elif isinstance(v, CfgNode):
            node_keys.add(opt)
            _, new_node_keys = _cfg2args(v.to_dict(), parser, opt+'.')
            node_keys.update(new_node_keys)        elif isinstance(v, bool):
            parser.add_argument('--'+opt, action='store_true', default=v)        else:
            parser.add_argument('--'+opt, type=type(v), default=v)    return parser, node_keys
登录后复制
       

其中,对CfgNode类型值的处理是比较特别的,这是因为argparse.ArgumentParser默认是无法自动解析CfgNode对象的。作为替代方案,我在这里先将CfgNode对象转换成字典,然后把原本的键名记录在node_keys中。当然,为argparse编写相关的扩展逻辑也是可以考虑的方案。经过转换之后,配置文件中的

A:
    B: 2
登录后复制
       

可以在命令行指定为--A.B 2。而

C:
    - 1
    - 2
登录后复制
       

则对应命令行选项的--C 1 2。需要特别注意的是,type为布尔型的命令行选项被设置了action='store_true',即

A: True
登录后复制
       

对应的是--A而不是--A True。

度加剪辑
度加剪辑

度加剪辑(原度咔剪辑),百度旗下AI创作工具

度加剪辑63
查看详情 度加剪辑

接下来是从命令行选项恢复键值对的转换:

def _args2cfg(cfg, args, node_keys):
    args = vars(args)    for k, v in args.items():
        pos = k.find('.')        if pos != -1:            # Iteratively parse a dict
            dict_ = cfg            while pos != -1:
                dict_.setdefault(k[:pos], {})
                dict_ = dict_[k[:pos]]
                k = k[pos+1:]
                pos = k.find('.')
            dict_[k] = v        else:
            cfg[k] = v    for k in node_keys:
        pos = k.find('.')        if pos != -1:            # Iteratively parse a dict
            dict_ = cfg            while pos != -1:
                dict_.setdefault(k[:pos], {})
                dict_ = dict_[k[:pos]]
                k = k[pos+1:]
                pos = k.find('.')
            v = dict_[k]
            dict_[k] = [CfgNode(v_) for v_ in v] if isinstance(v, list) else CfgNode(v)        else:
            v = cfg[k]
            cfg[k] = [CfgNode(v_) for v_ in v] if isinstance(v, list) else CfgNode(v)    return cfg
登录后复制
       

其中涉及两次遍历,第一次重建字典,第二次把之前在node_keys中记录的键对应的值转换为CfgNode对象(或包含CfgNode对象的容器)。

3.3 配置文件单继承的实现

这块的逻辑比较简单,从用户指定的配置文件出发,不断向上级检索_base_配置项,并导入配置项,直到检索不到_base_为止。最后将收集的各级配置融合。具体实现如下:

def _chain_maps(*maps):
    chained = dict()
    keys = set().union(*maps)    for key in keys:
        vals = [m[key] for m in maps if key in m]        if isinstance(vals[0], Mapping):
            chained[key] = _chain_maps(*vals)        else:
            chained[key] = vals[0]    return chaineddef read_config(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)    return cfg or {}def parse_configs(cfg_path, inherit=True):
    if inherit:
        cfgs = []
        cfgs.append(read_config(cfg_path))        while cfgs[-1].get('_base_'):
            base_path = cfgs[-1].pop('_base_')
            curr_dir = osp.dirname(cfg_path)
            cfgs.append(read_config(osp.normpath(osp.join(curr_dir, base_path))))        return _chain_maps(*cfgs)    else:        return read_config(cfg_path)
登录后复制
   

3.4 将配置系统接入PaddleRS

3.1-3.3小节实现的配置系统其实和PaddleRS是彼此独立的,这符合对不同模块解耦的要求,便于复用和维护。接下来将配置系统接入PaddleRS:

CfgNode.set_context(globals())

cfg = parse_args()print(format_cfg(cfg))# Automatically download dataif cfg['download_on']:
    paddlers.utils.download_and_decompress(cfg['download_url'], path=cfg['download_path'])if cfg['cmd'] == 'train':
    train_dataset = build_objects(cfg['datasets']['train'], mod=paddlers.datasets)
    train_transforms = T.Compose(build_objects(cfg['transforms']['train'], mod=T))    # XXX Late binding of transforms
    train_dataset.transforms = train_transforms
eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))# XXX Late binding of transformseval_dataset.transforms = eval_transforms

model = build_objects(cfg['model'], mod=paddlers.tasks)if cfg['losses']:
    model.losses = {        'types': build_objects(cfg['losses']['types']),        'coef': cfg['losses']['coef']
    }if cfg['cmd'] == 'train':    if cfg['optimizer']:        if len(cfg['optimizer'].args) == 0:
            cfg['optimizer'].args = {}        if not isinstance(cfg['optimizer'].args, dict):            raise TypeError        if cfg['optimizer'].args.get('parameters', None) is not None:            raise ValueError
        cfg['optimizer'].args['parameters'] = model.net.parameters()
        optimizer = build_objects(cfg['optimizer'], mod=paddle.optimizer)    else:
        optimizer = None

    model.train(
        num_epochs=cfg['num_epochs'],
        train_dataset=train_dataset,
        train_batch_size=cfg['train_batch_size'],
        eval_dataset=eval_dataset,
        optimizer=optimizer,
        save_interval_epochs=cfg['save_interval_epochs'],
        log_interval_steps=cfg['log_interval_steps'],
        save_dir=cfg['save_dir'],
        learning_rate=cfg['learning_rate'],
        early_stop=cfg['early_stop'],
        early_stop_patience=cfg['early_stop_patience'],
        use_vdl=cfg['use_vdl'],
        resume_checkpoint=cfg['resume_checkpoint'] or None,
        **cfg['train']
    )elif cfg['cmd'] == 'eval':
    state_dict = paddle.load(os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
    model.net.set_state_dict(state_dict)
    res = model.evaluate(eval_dataset)    print(res)
登录后复制
       

在调用parse_args()解析命令行参数/配置文件以前,首先调用CfgNode.set_context()将当前全局命名空间作为检索模块的上下文传给CfgNode。download_on选项控制是否由脚本自动下载和解压数据集。接下来的步骤和PaddleRS的官方tutorial中基本一致,除了根据cmd选项决定训练模型、还是只执行验证过程。

除此之外,format_cfg()函数将所有的配置项prettify为一个更易人类阅读的字符串,其具体实现如下:

def format_cfg(cfg, indent=0):
    s = ''
    if isinstance(cfg, dict):        for i, (k, v) in enumerate(sorted(cfg.items())):
            s += ' '*indent+str(k)+': '
            if isinstance(v, (dict, list, CfgNode)):
                s += '\n'+format_cfg(v, indent=indent+1)            else:
                s += str(v)            if i != len(cfg)-1:
                s += '\n'
    elif isinstance(cfg, list):        for i, v in enumerate(cfg):
            s += ' '*indent+'- '
            if isinstance(v, (dict, list, CfgNode)):
                s += '\n'+format_cfg(v, indent=indent+1)            else:
                s += str(v)            if i != len(cfg)-1:
                s += '\n'
    elif isinstance(cfg, CfgNode):
        s += ' '*indent+f"type: {cfg.type}"+'\n'
        s += ' '*indent+f"module: {cfg.module}"+'\n'
        s += ' '*indent+'args: \n'+format_cfg(cfg.args, indent+1)    return s
登录后复制
   

4 效果检验


4.1 安装依赖库

In [ ]
# 从Gitee下载PaddleRS(个人维护的镜像)# 如果目录已经存在,则不会重复下载![ ! -d "PaddleRS" ] && git clone https://gitee.com/bobholamovic/PaddleRS# 安装PaddleRS!pip install -r PaddleRS/requirements.txt
!pip install -e PaddleRS/
登录后复制
   
In [ ]
# 安装GDAL!pip install GDAL-3.4.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl
登录后复制
   

4.2 编写配置文件

以变化检测任务为例,为了充分利用配置系统的配置文件继承功能,在work/configs/_base_/目录中新建cd_base.yml文件,其中存储不同变化检测模型共有的通用配置项,如数据集路径等。部分配置项如下所示:

datasets:
    train: !Node
        type: CDDataset
        args: 
            data_dir: /home/aistudio/work/data/airchange/
            file_list: /home/aistudio/work/data/airchange/train.txt
            label_list: null
            num_workers: 0
            shuffle: True
            with_seg_labels: False
            binarize_labels: True
    eval: !Node
        type: CDDataset
        args:
            data_dir: /home/aistudio/work/data/airchange/
            file_list: /home/aistudio/work/data/airchange/eval.txt
            label_list: null
            num_workers: 0
            shuffle: False
            with_seg_labels: False
            binarize_labels: Truenum_epochs: 5train_batch_size: 4save_interval_epochs: 3
登录后复制
       

对于CfgNode节点,根据yaml语法,只需要指定!Node tag即可,除type为必选项外,args和module均为可选项。

在work/configs/cd/目录中新建bit_default.yaml文件,使用_base_选项以相对路径指定要继承的配置文件:

_base_: ../_base_/cd_base.yaml
登录后复制
       

对于model项,不指定args和module,即使用默认参数构造模型。

model: !Node
       type: BIT
登录后复制
       

work/configs/cd/bit_custom.yaml文件中指定了更多的自定义选项,例如:

losses:
    types:
        - !Node
          type: CrossEntropyLoss
          module: paddlers.models.ppseg.models
        - !Node
          type: MixedLoss
          args:
            losses:
                - !Node
                  type: CrossEntropyLoss
                  module: paddlers.models.ppseg.models
                - !Node
                  type: DiceLoss
                  module: paddlers.models.ppseg.models
            coef: [0.8, 0.2]          module: paddlers.models.ppseg.models
    coef: [1.0, 0.5]    
model: !Node
       type: BIT
       args:
           # num_classes
           - 2
           # use_mixed_loss
           - False
           # in_channels
           - 3
           # backbone
           - resnet34
登录后复制
       

其中,构造BIT使用的args使用的是数组的形式(当然,也可以用更直观的键值对的形式),yaml文件中指定的losses将最终等价于构造如下对象:

import paddlers.models.ppseg.models as ppseg
losses = {    'types': [ppseg.CrossEntropyLoss(), ppseg.MixedLoss(losses=[ppseg.CrossEntropyLoss(), ppseg.DiceLoss()], coef=[0.8, 0.2])],    'coef': [1.0, 0.5]
}
登录后复制
       

除此之外,对于场景分类、目标检测、图像分割任务,我也编写了一些示例配置文件(参数与PaddleRS的官方tutorial中一致),存放在work/configs/clas/、work/configs/det/和work/configs/seg目录中。

4.3 执行代码

In [ ]
# 切换工作路径%cd work/
登录后复制
   
In [ ]
# 变化检测任务# 解压数据集!unzip -oq -d data/airchange/ /home/aistudio/data/data77781/SZTAKI_AirChange_Benchmark.zip# BIT模型训练,使用基础配置!python configurable_paddlers.py train --config configs/cd/bit_default.yaml# BIT模型验证!python configurable_paddlers.py eval --config configs/cd/bit_default.yaml --resume_checkpoint exp/cd/bit_default/best_model/
登录后复制
   
In [ ]
# BIT模型训练,使用自定义配置,修改了优化器、损失函数和模型构造参数!python configurable_paddlers.py train --config configs/cd/bit_custom.yaml
登录后复制
   
In [ ]
# 从命令行选项修改部分配置!python configurable_paddlers.py train --config configs/cd/bit_custom.yaml --train_batch_size 2
登录后复制
   
In [ ]
# 场景分类任务# 解压数据集!unzip -oq -d data/ucmerced/ /home/aistudio/data/data51628/UCMerced_LandUse.zip# HRNet模型!python configurable_paddlers.py train --config configs/clas/hrnet.yaml
登录后复制
   
In [ ]
# 目标检测任务# PP-YOLO模型!python configurable_paddlers.py train --config configs/det/ppyolo.yaml
登录后复制
   
In [ ]
# 图像分割任务# UNet模型,使用run_with_clean_log.py筛除GDAL警告!python run_with_clean_log.py "python configurable_paddlers.py train --config configs/seg/unet.yaml"
登录后复制
   

5 后记


本项目尝试为PaddleRS添加了一个轻量级、非侵入式、可继承的袖珍配置系统,支持yaml文件与命令行选项两种配置方式。作为一个初版实现,本项目还十分粗糙,旨在抛砖引玉。后续可以考虑的改进点有很多,例如简化配置文件的编写以及为CfgNode._get_module()实现动态作用域

以上就是为PaddleRS添加一个袖珍配置系统的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习
PHP中文网抖音号
发现有趣的

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号