Breakout DQN : prétraitement et Experience Replay

Suite du sujet Breakout DQN : Fichier d'entrainement :

Structure réelle de la mémoire de rejeu

Dans un précédent sujet, j’avais expliqué que l’on stockait dans la mémoire de rejeu les dernières expériences où chacune est un quintuplet (s_t,a_t,s_{t+1},r_{t},l_{t}) dans lequel :
-s_t représente l’état de départ (1 état = 4 frames)
-a_t représente l’action choisie durant l’état s_t
-s_{t+1} représente l’état d’arrivé
-r_{t} représente la récompense associée à l’état s_{t+1}
-l_{t} indique la perte d’une vie

Naïvement, on a envie d’écrire un code qui ressemblerait à ceci :

xp=(state,action,state_prime,reward,lost_a_life)#tuple de 5 éléments
ReplayMemory.add(xp)

Le problème, dans ce cas, serait que la mémoire de rejeu occuperait beaucoup plus de place que nécessaire. En effet s_{t} et s_{t+1} ont des frames communes :
state_breakout
Pour 1 million d’expériences avec la méthode actuelle, on stockerait 8 millions de frames. Si l’on considère qu’une frame fait 7ko, il faudrait alors 56 Go en RAM pour seulement les stocker.
De ce fait, la mémoire de rejeu aura cette structure :

\begin{bmatrix} \dots & f_{t-3} & f_{t-2} & f_{t-1}& f_{t} & f_{t+1}\\ \dots & a_{t-4} & a_{t-3} & a_{t-2}& a_{t-1}& a_{t}\\ \dots & r_{t-4} & r_{t-3} & r_{t-2}& r_{t-1}& r_{t} \\ \dots & l_{t-4} & l_{t-3} & l_{t-2}& l_{t-1}& l_{t}\\ \end{bmatrix}

s_{t}=(f_{t-3},f_{t-2},f_{t-1},f_{t}) et s_{t+1}=(f_{t-2},f_{t-1},f_{t},f_{t+1})

En conséquence, on remplacera le quintuplet (s_t,a_t,s_{t+1},r_{t},l_{t}) par un quadruplet (f_{t+1},a_t,r_{t},l_{t}) pour ajouter une expérience dans la mémoire de rejeu.

Prétraitement des images du jeu

Stocker les images originales prendraient trop de place en mémoire et c’est aussi plus compliqué à analyser pour notre IA. Pour cette raison, on va rétrécir l’image et transformer les couleurs en noir et blanc.

Code du prétraitement:

def preprocess_frame(frame):
    frame = np.mean(frame,axis=2)
    frame = resize(frame, (110, 84), anti_aliasing=True)
    frame = frame[17:101,:]
    frame[frame>=60]=255
    frame[frame<60]=0
    frame = frame.astype(np.uint8) #très important d'encoder en 8 bits plutôt qu'en 32 bits (4 fois moins d'espaces)
    return frame

Le prétraitement comporte 5 étapes :

  1. On calcule la moyenne des valeurs RGB pour chaque pixel, ce qui nous donne une image en niveau de gris.
  2. On redimensionne l’image en gardant les proportions, elle passe alors d’une taille de 210*160 à une taille de 110*84
  3. On ne conserve que les pixels des lignes compris entre 17 et 101 car ils n’apportent aucune information utile en dehors de cet intervalle. La nouvelle image a donc une taille de 84*84
  4. On transforme notre image en noir et blanc : tous les pixels qui ont une valeur supérieure à 60 deviendront blancs sinon ils seront noirs. J’ai choisi la valeur 60 car en considérant celles de l’image à l’étape précédente, j’ai constaté qu’elle représentait une bonne marge d’erreur entre les blancs et les noirs.
  5. On change l’encodage de l’image pour la passer en 8 bits plutôt qu’en gardant un encodage flottant.

Voici le rendu (l’étape 5 est omise):

La classe ReplayMemory

La structure de données utilisée : deque

from collections import deque
import skimage
from skimage.transform import resize
import numpy as np
from random import randint

class ReplayMemory():
def __init__(self,size):
             self.frames=deque([],maxlen=size)
             self.actions=deque([],maxlen=size)
             self.rewards=deque([],maxlen=size)
             self.game_over=deque([],maxlen=size)

Deque est une structure de données où la durée des opérations de lecture/écriture est proportionnelle à la distance aux extrémités de la liste : le pire cas concerne donc le milieu. Si on dépasse la taille spécifiée par maxlen on aura simplement un décalage à gauche/droite des données :

a=deque([1,2,3,4],maxlen=4)
a.append(5)#ajout du 5 à droite
#a=[2 3 4 5]
a.appendleft(1)#ajout du 1 à gauche
#a=[1 2 3 4]

Ajout des expériences à la mémoire de rejeu

L’émulateur utilisé a une particularité un peu embêtante, il indique qu’on a perdu une vie uniquement lorsque la balle est remise en jeu (7 frames après la vraie perte de balle) :

Par exemple, sur cette séquence de frames, la vie a été perdue à la frame 20. Cependant, l’émulateur ne nous l’indique qu’à la frame 27. Pour cette raison, on devra prendre cette donnée en compte lorsque l’on ajoutera une expérience. De même, on supprimera les frames qui n’apportent aucune information : les 6 qui précèdent celle où la perte de vie est indiquée par l’émulateur.

def delete_last_xp(self):
    self.frames.pop()
    self.actions.pop()
    self.recompenses.pop()
    self.game_over.pop()
def add_experience(self,frame,action,reward,game_over):
    frame=preprocess_frame(frame)
    reward=np.sign(reward)#les récompenses auront comme valeurs possibles 0 ou 1.
    if(game_over):
        for i in range (0,6):
            self.delete_last_xp()
        last_element=len(self.game_over)-1
        self.game_over[last_element]=True
        game_over=False
    self.frames.append(frame)
    self.actions.append(action)
    self.rewards.append(reward)
    self.game_over.append(game_over)

Si l’émulateur indique la fin d’un round alors on supprime les 6 dernières expériences et on attribue la perte de vie à la bonne expérience.
Ensuite, on ajoute la dite expérience dans tous les cas.

Récupération des états

s_{t} et s_{t+1} sont composés de 5 frames consécutives, la méthode ci-dessous prend en entrée l’indice de la dernière et renvoie les états s_{t} et s_{t+1}.

def GetStates(self,id_xp):
    frame_t5=self.frames[id_xp]
    frame_t4=self.frames[id_xp-1]
    frame_t3=self.frames[id_xp-2]
    frame_t2=self.frames[id_xp-3]
    frame_t1=self.frames[id_xp-4]
    state_primeDQN=np.stack((frame_t2,frame_t3,frame_t4,frame_t5),axis=2)
    stateDQN=np.stack((frame_t1, frame_t2,frame_t3,frame_t4),axis=2)
    return (stateDQN,state_primeDQN)

Dans notre code d’entrainement, notre agent avait besoin de récupérer le dernier état dans lequel il se trouvait pour l’utiliser dans son choix d’actions :

def GetLastState(self):
    last_id=len(self.frames)-1
    ( _ , stateDQN)=self.GetStates(last_id)
    return stateDQN

Récupération du mini-batch d’entraînements

Dans ce qui suit, 1 partie équivaut à 1 perte de vie et non pas celle de 5 vies.

Lorsqu’on récupère des expériences d’entraînements, on doit être vigilant concernant 2 aspects:
-Ne pas récupérer d’états stockés après la dernière partie complète. En effet tant que la partie n’est pas finie, il est impossible de connaitre la frame où la vie sera perdue.
-Ne pas récupérer d’états dont les frames appartiennent à 2 parties différentes.

La fonction ci-dessous prend une liste de booléens et renvoie l’id de la dernière case qui vaut VRAI (donc partie complète) :

def last_true(game_over_list):
        i=len(game_over_list)-1
        while (game_over_list[i]==False):
                i=i-1
        return i

Il s’agit maintenant de récupérer n indices permettant de recréer les états associés avec la méthode GetStates. Un indice tiré au sort doit respecter les 2 conditions suivantes:
-ne pas avoir déjà été tiré
-ne pas faire parti des 4 premières frames d’une partie

def n_random_xp(n,game_over_list,last_true):
    x=list()
    while(len(x)<n):
        a=randint(4,last_true)
        if(not((a in x) or game_over_list[a-1] or game_over_list[a-2] or game_over_list[a-3] or game_over_list[a-4])):
            x.append(a)
    return x

On peut maintenant écrire la méthode pour récupérer le mini-batch:

def GetMinibatch(self, batch_size):
    minibatch=list()
    game_over_list=list(self.game_over)
    last_id_true=last_true(game_over_list)
    indices=n_random_xp(batch_size,game_over_list,last_id_true)
    for i in indices:
        (state, state_prime)=self.GetStates(i)
        action=self.actions[i]
        reward=self.rewards[i]
        game_over=self.game_over[i]
        experience=(state, action, state_prime, reward, game_over)
        minibatch.append(experience)
    return minibatch

On pourrait encore améliorer

Il existe une technique qui s’appelle Prioritized Experience Replay (PER) qui peut aider grandement notre agent à s’améliorer. Je ne l’ai pas utilisée ici car je ne voulais pas vous perdre sous un flot d’informations. Je vous la réserve donc pour les méthodes basées sur les gradients de politiques.

Passons à la partie 3 et construisons la classe de notre agent.