Breakout DQN : Agent avec Double DQN

Suite du sujet Breakout DQN : prétraitement et Experience Replay :

La classe DQNAgent

L’agent est composé de 2 DQN

from collections import deque
import breakout_dqn_oh
import numpy as np

class DQNAgent:
    def __init__(self):
        self.DQN_online=breakout_dqn_oh.BreakOutDQN()#celui qui joue
        self.DQN_target=breakout_dqn_oh.BreakOutDQN()#celui qui estime Q(s,a)
        self.Online2Target()#on initialise avec les même poids

Choix de l’action par l’agent

Il délègue simplement cette tâche au DQN primaire et récupère aussi la valeur q de l’action : cette dernière n’est pas obligatoire mais très utile pour vérifier qu’elle n’explose pas durant l’apprentissage.

def GetAction(self,state):
    (qaction, action)=self.DQN_online.GetAction(state)
    return (qaction, action)

Transformation de l’action choisie en format one-hot

def int2oh(action,nbr_actions):
    # supposons a=1
    action_oh=np.zeros(nbr_actions)# 0 0 0 0
    action_oh[action]=1# 0 1 0 0
    return action_oh

Apprentissage de l’agent

Version 1 : Target Network

Rappelons la formule d’apprentissage:

Q(s,a)=r+\gamma \max_{a'}{Q(s',a')}\\

Dans l’approche Target Network, on utilise le second DQN pour estimer Q(s',a')

def learning(self,gamma,minibatch,batch_size):
    states,actions,states_prime, rewards, game_over = map(list,zip(*minibatch))
    #D'abord on estime les valeurs de Qmax(s',a') avec le target DQN
    Q_values_prime=self.DQN_target.GetQValues(states_prime,game_over,batch_size)
    #Ensuite on estime la qualité de l'état de départ via la formule du Qlearning
    actions_oh=np.zeros((batch_size,4))
    q_oh=np.zeros((batch_size,4))
    for i in range(0,batch_size):
        #formule du q-learning
        Q_values= rewards[i] + gamma * Q_values_prime[i]
        #on les transforme en format one-hot
        action=actions[i]
        actions_oh[i]=int2oh(action,4)
        q_oh[i]=actions_oh[i]*Q_values
    #Maintenant on entraine le DQN primaire
    self.DQN_primary.SetWeights(states, actions_oh, q_oh, batch_size)

Version 2 : Double DQN (Target DQN + Double Q-learning)

On a vu précédemment que l’algorithme du Q-learning, à cause de sa fonction max, a tendance à surestimer les valeurs Q des actions. Une méthode appelée Double Q-learning a été utilisée pour corriger ce problème : elle nécessite 2 fonctions Q indépendantes.
On a aussi vu que l’algorithme du Deep Q-learning nécessitait un 2ème DQN, appelé DQN cible (Target) pour être stable.

L’approche du Double DQN est simplement d’utiliser notre second DQN dans la méthode du Double Q-learning :
-le DQN primaire choisit l’action
-le DQN secondaire calcule la valeurs Q de cette action

On a juste à remplacer:

#D'abord on estime les valeurs de Qmax(s',a') avec le target DQN
Q_values_prime=self.DQN_target.GetQValues(states_prime,game_over,batch_size)

par

#D'abord on estime les valeurs de Qmax(s',a') avec l'algorithme du double dqn
Q_values_prime=self.Double_Qlearning(states_prime,game_over,batch_size)
def Double_Qlearning(self,states_prime,game_over,batch_size):
        #le DQN primaire choisi l'action
        actions=self.DQN_online.GetActions(states_prime,batch_size)
        #le DQN secondaire calcule les valeurs Q
        q_values_target=self.DQN_target.GetQValues(states_prime, game_over,batch_size)
        #on remplace les valeurs q des actions du 1er DQN par celles du 2ème
        q_values=np.zeros(batch_size)
        for i in range(0,batch_size):
            id_action=int(actions[i])
            q_values[i]=q_values_slave[i][id_action]
        return q_values

Mise à jour du Target DQN

On a vu dans le code d’entraînements que le 1er DQN devait recopier ses poids dans le 2ème DQN à une période définie. Pour ce faire, j’ai simplement sauvegardé/rechargé les poids par le biais d’un fichier.

def Online2Target(self):
    self.DQN_primary.SaveWeights("dqn-model")
    self.DQN_target.LoadWeights("dqn-model")

Sauvegarde/Chargement de notre Agent

def SaveAgent(self,filename):
        self.DQN_primary.SaveWeights(filename)

def LoadAgent(self,filename):
        self.DQN_primary.LoadWeights(filename)

Dernière ligne droite

On va maintenant s’attaquer à la classe de notre DQN.