我目前用 DDPG 算法预测股价，请问 action 该怎么写才不会报错？

问题遇到的现象和发生背景
问题相关代码,请勿粘贴截图

def _action_to_decision(action):
    """Collapse the raw DDPG actor output into one discrete trading decision.

    Returns 0 (hold), 1 (add 20% position) or 2 (reduce 20% position); these
    codes match the branches in ``dqn``.

    Why this exists: the actor returns an ndarray, so ``action == 1`` produces
    a boolean *array*, and using that array in an ``if`` raises
    "The truth value of an array with more than one element is ambiguous".
    Converting to a plain Python int first removes the ambiguity.

    NOTE(review): the mapping below is an assumption about the actor's output
    layout -- confirm against the agent's ``action_size``:
      * more than one element: one score per decision, highest score wins
        (index 1 = buy, index 2 = sell, any other index = hold);
      * exactly one element: a value presumably in [-1, 1]; above +1/3 buys,
        below -1/3 sells, anything in between holds.

    Raises:
        ValueError: if the action contains no elements.
    """
    values = np.asarray(action, dtype=float).ravel()
    if values.size == 0:
        raise ValueError("agent.act() returned an empty action")
    if values.size > 1:
        return int(np.argmax(values))
    signal = float(values[0])
    if signal > 1 / 3:
        return 1
    if signal < -1 / 3:
        return 2
    return 0


def dqn(n_episodes=EPISODE_COUNT, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
    """Train the agent on ``stockData`` and return the per-episode returns.

    Each episode starts with 10000 in cash and a flat position. At every step
    the actor's output is turned into a discrete decision (hold / +20% /
    -20% of position) via ``_action_to_decision``. The raw action, not the
    discrete decision, is what gets passed to ``agent.step``.

    Args:
        n_episodes: maximum number of training episodes.
        eps_start: initial value passed as the second argument of
            ``agent.act``.
        eps_end: lower bound for that value.
        eps_decay: multiplicative decay applied after every step.
            NOTE(review): a typical DDPG ``act(state, add_noise=True)`` only
            tests this argument for truthiness -- confirm the agent's API.

    Returns:
        list of ``(money - 10000) / 10000`` values, one per episode played.

    Side effects:
        Writes checkpoint_actor.pth / checkpoint_critic.pth when the mean
        return over the last 100 episodes exceeds 0.2 (and stops training),
        and again once training ends.
    """
    scores = []
    scores_window = deque(maxlen=100)  # returns of the most recent 100 episodes
    eps = eps_start
    for i_episode in range(1, n_episodes + 1):
        print("Episode" + str(i_episode))
        state = getState(stockData, 0, STATE_SIZE + 1)
        pos_old = 0  # position as a fraction of money, kept within [0, 1]
        money_initial = 10000  # starting capital
        money = money_initial
        total_share = 0  # number of shares held
        agent.inventory = []

        for t in range(l):
            action = agent.act(state, eps)
            # Branch on a plain int, never on the ndarray itself.
            decision = _action_to_decision(action)
            next_state = getState(stockData, t + 1, STATE_SIZE + 1)

            if decision == 1:  # add 20% position, capped at fully invested
                pos_new = min(pos_old + 0.2, 1)
                total_share += money * (pos_new - pos_old) / stockData[t]
            elif decision == 2:  # reduce 20% position, floored at flat
                pos_new = max(pos_old - 0.2, 0)
                total_share += money * (pos_new - pos_old) / stockData[t]
            else:  # hold
                pos_new = pos_old

            money = money_calculate(money, total_share, stockData[t], pos_new)
            done = 1 if money < 0 or t == l - 1 else 0
            # NOTE(review): this is the cumulative return so far rather than
            # the per-step change; kept as in the original design.
            reward = (money - money_initial) / money_initial
            # Store the raw continuous action: the critic must learn Q(s, a)
            # for what the actor actually produced, not the discretised choice.
            agent.step(state, action, reward, next_state, done)
            eps = max(eps_end, eps * eps_decay)
            state = next_state
            pos_old = pos_new

            if done:
                print("------------------------------")
                print("total_profit = " + str((money - money_initial) / money_initial))
                print("------------------------------")
                break

        episode_return = (money - money_initial) / money_initial
        scores.append(episode_return)
        scores_window.append(episode_return)
        if len(scores_window) == 100 and np.mean(scores_window) > 0.2:
            torch.save(agent.actor_local.state_dict(), 'checkpoint_actor.pth')
            torch.save(agent.critic_local.state_dict(), 'checkpoint_critic.pth')
            break

    torch.save(agent.actor_local.state_dict(), 'checkpoint_actor.pth')
    torch.save(agent.critic_local.state_dict(), 'checkpoint_critic.pth')
    return scores
运行结果及报错内容

Episode1
C:\Users\22536\.conda\envs\torch\lib\site-packages\torch\nn\functional.py:1795: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
Traceback (most recent call last):
File "D:/Git/stockPrediction2/main_new1.py", line 124, in <module>
scores = dqn()
File "D:/Git/stockPrediction2/main_new1.py", line 45, in dqn
if action == 1 :# 加仓20%
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Process finished with exit code 1

img

我的解答思路和尝试过的方法
我想要达到的结果

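值错误（ValueError）：包含多个元素的数组的真值是不明确的，应使用 a.any() 或 a.all()。
原建议是把 if action 改成 if action.all() 再运行看看。这样虽然能消除报错，却选不出正确的操作。DDPG 的 actor 输出的是一个连续的动作向量，action.all() 只是判断所有元素是否都非零，和"加仓/减仓/持仓"无关。正确做法分两步：先把 action 转成一个离散决策，例如 decision = int(np.argmax(action))（1 为加仓，2 为减仓，其余为持仓）；再用 if decision == 1: / elif decision == 2: 判断。传给 agent.step 的仍应是原始的 action。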