DQN target network fails to update its weights

The target network of my DQN fails to update its weights, and I can't figure out why.
I'd be grateful if anyone could help, thanks!
Here is the DQN code:

import random

import numpy as np
import tensorflow as tf
from tensorflow import keras


# Experience replay buffer
class ExperienceReplayBuffer:
    def __init__(self, capacity=20000):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        # Grow until capacity, then overwrite the oldest transition in a ring
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        # np.random.choice cannot sample from a list of tuples, so use random.sample
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states), np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states), np.array(dones, dtype=np.float32))

    def __len__(self):
        # Return the live length; a cached self.length set in __init__ would
        # stay at 0 forever and prevent training from ever starting
        return len(self.buffer)


# DQN model: three fully connected layers
class DQNModel(tf.keras.Model):
    def __init__(self, num_action):
        super(DQNModel, self).__init__()
        self.dense1 = keras.layers.Dense(256, kernel_regularizer=keras.regularizers.L2(0.001), activation='relu')
        self.dense2 = keras.layers.Dense(64, kernel_regularizer=keras.regularizers.L2(0.001), activation='relu')
        self.dense3 = keras.layers.Dense(num_action, kernel_regularizer=keras.regularizers.L2(0.001), activation='linear')

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        outputs = self.dense3(x)
        return outputs


# DQN agent
class DQNAgent:
    def __init__(self, num_action):
        self.num_actions = num_action
        self.model = DQNModel(num_action)
        self.target_model = DQNModel(num_action)
        self.target_model.set_weights(weights=self.model.get_weights())
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.loss_function = tf.keras.losses.MeanSquaredError()

    def update_target_model(self):
        self.target_model.set_weights(self.model.get_weights())

    def get_action(self, state, epsilon):
        # Epsilon-greedy: the random branch must return a valid action index
        # (0..num_actions-1) to match the tf.one_hot mask used in train()
        if np.random.rand() <= epsilon:
            return np.random.randint(self.num_actions)
        else:
            q_values = self.model.predict(state)
            return np.argmax(q_values[0])

    def train(self, replay_buffer, batch_size, discount_factor):
        if len(replay_buffer) < batch_size:
            return
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
        # TD target: r + gamma * max_a' Q_target(s', a'), zeroed at episode ends
        next_q_values = self.target_model.predict(next_states)
        max_next_q_values = np.max(next_q_values, axis=1)
        target_q_values = rewards + (1 - dones) * discount_factor * max_next_q_values
        mask = tf.one_hot(actions, self.num_actions)
        with tf.GradientTape() as tape:
            q_values = self.model(states)
            # Select Q(s, a) for the actions actually taken
            q_action = tf.reduce_sum(tf.multiply(q_values, mask), axis=1)
            loss = self.loss_function(target_q_values, q_action)
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

The error is: You called set_weights(weights) on layer "dqn_model_1" with a weight list of length 6, but the layer was expecting 0 weights. Provided weights: [......]

[The following answer was generated by GPT]

Solution:

The error message itself points at the cause. set_weights was called with a list of 6 weights (a kernel and a bias for each of the three Dense layers), but the target model reported that it expects 0 weights. Subclassed Keras models create their variables lazily, on the first call with real input. So by the time update_target_model runs, self.model has already been built (for example by an earlier predict call in get_action), while self.target_model has never been called and therefore has no variables to assign to.

To fix this, build both networks before the first weight copy, for example by running a dummy forward pass through each of them.
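
You can verify the lazy building directly with the DQNModel defined above (a minimal sketch; the 4-dimensional dummy state is just an illustrative assumption):

m = DQNModel(num_action=2)
print(len(m.get_weights()))  # 0 -- no variables exist before the first call
m(tf.zeros((1, 4)))          # the first forward pass builds all three Dense layers
print(len(m.get_weights()))  # 6 -- one kernel and one bias per Dense layer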

Here is a possible solution:

First, build both networks in DQNAgent.__init__ before the initial weight copy (state_dim, the size of the observation vector, is an extra constructor argument assumed here):

class DQNAgent:
    def __init__(self, num_action, state_dim):
        self.num_actions = num_action
        self.model = DQNModel(num_action)
        self.target_model = DQNModel(num_action)
        # Subclassed Keras models build their weights lazily, so run a dummy
        # forward pass through both networks before the first weight copy
        dummy_state = tf.zeros((1, state_dim))
        self.model(dummy_state)
        self.target_model(dummy_state)
        self.update_target_model()  # both weight lists now have length 6

    def update_target_model(self):
        self.target_model.set_weights(self.model.get_weights())

Then, in the training loop, copy the online network's weights to the target network every fixed number of steps:

def train_dqn(agent, env, replay_buffer, num_episodes, max_steps_per_episode,
              batch_size=64, discount_factor=0.99, target_update_interval=100):
    # Training loop
    for episode in range(num_episodes):
        for step in range(max_steps_per_episode):
            # Choose an action, step the environment, push the transition
            # into the replay buffer, etc.
            # ...

            # Periodically sync the target network with the online network
            if step % target_update_interval == 0:
                agent.update_target_model()

            # Sample a minibatch from the replay buffer and update the online model
            agent.train(replay_buffer, batch_size, discount_factor)
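
For example, with a Gym environment (the environment name and loop sizes below are purely illustrative assumptions):

import gym

env = gym.make("CartPole-v1")
agent = DQNAgent(num_action=env.action_space.n,
                 state_dim=env.observation_space.shape[0])
buffer = ExperienceReplayBuffer(capacity=20000)
train_dqn(agent, env, buffer, num_episodes=500, max_steps_per_episode=200)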

With these changes, the target network is built before the first weight copy, so set_weights receives a weight list of the expected length, and the target network is then kept in sync with the online network at a fixed interval.
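
As an aside, many DQN implementations use a soft (Polyak) target update instead of the periodic hard copy, blending the weights a little on every training step; a minimal sketch (tau is an assumed hyperparameter, and both models must still be built first, for the same reason as above):

def soft_update_target_model(self, tau=0.005):
    # target <- tau * online + (1 - tau) * target
    new_weights = [tau * w + (1 - tau) * tw
                   for w, tw in zip(self.model.get_weights(),
                                    self.target_model.get_weights())]
    self.target_model.set_weights(new_weights)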

I hope this solution helps! If the problem persists, please share more of the relevant code and the full traceback so we can dig further.

