在使用tensorflow tf-agents库构建强化学习dqn代理时,开发者可能会遇到一个特定的运行时错误,尤其是在调用代理的探索策略(agent.collect_policy.action(time_step))时。错误信息通常如下所示:
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__Select_device_/job:localhost/replica:0/task:0/device:CPU:0}} 'then' and 'else' must have the same size. but received: [1] vs. [] [Op:Select] name:
值得注意的是,通常情况下,调用代理的标准策略(agent.policy.action(time_step))可能不会触发此错误。这表明问题可能与collect_policy内部的特定逻辑(例如,探索机制,如epsilon-greedy策略)有关,而不仅仅是TimeStep与TimeStepSpec的通用匹配问题。
该错误信息明确指出,TensorFlow内部的Select操作(对应于Python中的tf.where)在比较其then和else分支的张量大小时发现不一致。具体来说,它接收到一个形状为[1]的张量和一个形状为[](即标量)的张量,导致操作失败。
tf_agents库在定义环境和代理的交互接口时,严格依赖于TimeStepSpec和ActionSpec来描述期望的张量结构。TimeStepSpec定义了每个时间步(TimeStep)中各个组件(如step_type、reward、discount、observation)的预期形状、数据类型和取值范围。
InvalidArgumentError的根本原因在于TimeStepSpec中对标量组件的形状定义与collect_policy内部处理这些组件时的预期形状不一致。
问题就出在这里:如果TimeStepSpec将reward、discount、step_type等定义为shape=(1,)(意图表示“一个批次中有一个元素”),而collect_policy内部(特别是像epsilon_greedy_policy这样的策略,它可能在内部对单个元素执行tf.where操作)却期望这些组件的元素本身是标量(即shape=()),那么就会发生冲突。tf.where操作会尝试将一个[1]形状的张量(来自TimeStepSpec中shape=(1,)的假设)与一个[]形状的张量(来自策略内部对标量的处理)进行比较,从而抛出InvalidArgumentError。
解决此问题的关键在于确保TimeStepSpec中对标量组件的形状定义是正确的,即使用shape=()。tf_agents的策略会自动处理输入TimeStep中的批次维度。
在原始问题中,TimeStepSpec的定义可能如下所示,其中step_type、reward、discount的shape被错误地设置为(1,):
import tensorflow as tf from tf_agents.specs import tensor_spec from tf_agents.trajectories.time_step import TimeStep # ... 其他定义,如amountMachines ... # 错误的 TimeStepSpec 定义 time_step_spec = TimeStep( step_type=tensor_spec.BoundedTensorSpec(shape=(1,), dtype=tf.int32, minimum=0, maximum=2), reward=tensor_spec.TensorSpec(shape=(1,), dtype=tf.float32), discount=tensor_spec.TensorSpec(shape=(1,), dtype=tf.float32), observation=tensor_spec.TensorSpec(shape=(1, amountMachines), dtype=tf.int32) )
对于step_type、reward和discount这些本质上是标量的组件,它们的TensorSpec形状应该定义为(),表示它们是0维张量。
import tensorflow as tf from tf_agents.specs import tensor_spec from tf_agents.trajectories.time_step import TimeStep from tf_agents.agents.dqn import dqn_agent from tf_agents.utils import common # 假设 amountMachines 和 model 已定义 amountMachines = 6 # 示例值 # model = ... # 您的 Q 网络模型 # train_step_counter = tf.Variable(0) # 训练步数计数器 # learning_rate = 1e-3 # 学习率 # 正确的 TimeStepSpec 定义 time_step_spec = TimeStep( step_type=tensor_spec.BoundedTensorSpec(shape=(), dtype=tf.int32, minimum=0, maximum=2), reward=tensor_spec.TensorSpec(shape=(), dtype=tf.float32), discount=tensor_spec.TensorSpec(shape=(), dtype=tf.float32), observation=tensor_spec.TensorSpec(shape=(1, amountMachines), dtype=tf.int32) ) # 动作空间定义(保持不变) num_possible_actions = 729 action_spec = tensor_spec.BoundedTensorSpec( shape=(), dtype=tf.int32, minimum=0, maximum=num_possible_actions - 1) # 代理初始化(保持不变) # optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) # agent = dqn_agent.DqnAgent( # time_step_spec, # action_spec, # q_network=model, # optimizer=optimizer, # epsilon_greedy=1.0, # td_errors_loss_fn=common.element_wise_squared_loss, # train_step_counter=train_step_counter) # agent.initialize()
即使TimeStepSpec中这些组件的形状是(),在创建实际的TimeStep实例时,由于通常会处理批次数据(即使批次大小为1),我们仍然需要将标量值包装成一个包含单个元素的张量。例如,tf.convert_to_tensor([value], dtype=...)会创建一个形状为(1,)的张量,这对于批次大小为1的情况是正确的。tf_agents的策略会正确地处理这种批次维度。
# 假设 get_states() 返回一个 NumPy 数组,例如 [4,4,4,4,4,6] # 假设 step_type, reward, discount 也是单个数值 current_state = tf.constant([4,4,4,4,4,6], dtype=tf.int32) # 示例状态 current_state_batch = tf.expand_dims(current_state, axis=0) # 形状变为 (1, 6) step_type_val = 0 # 示例值 reward_val = 0.0 # 示例值 discount_val = 0.95 # 示例值 # TimeStep 数据的创建方式(保持不变) # 注意:即使 TimeStepSpec 中 shape=(),这里仍然创建形状为 (1,) 的张量 time_step = TimeStep( step_type=tf.convert_to_tensor([step_type_val], dtype=tf.int32), reward=tf.convert_to_tensor([reward_val], dtype=tf.float32), discount=tf.convert_to_tensor([discount_val], dtype=tf.float32), observation=current_state_batch ) # 调用 collect_policy (现在应该正常工作) # action_step = agent.collect_policy.action(time_step)
通过遵循这些最佳实践,可以有效避免TF-Agents中常见的形状不匹配问题,确保强化学习代理的训练和执行流程顺畅。
以上就是TensorFlow TF-Agents DQN collect_policy InvalidArgumentError: 解决 then 和 else 尺寸不匹配问题的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号