-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathreplay_memory.py
More file actions
executable file
·107 lines (94 loc) · 4.13 KB
/
replay_memory.py
File metadata and controls
executable file
·107 lines (94 loc) · 4.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Code is modified from https://github.com/tambetm/simple_dqn/blob/master/src/replay_memory.py"""
import os
import random
import logging
import numpy as np
from utils import save_npy, load_npy
class ReplayMemory:
    """Fixed-size circular buffer of transitions for DQN-style training.

    Slot ``i`` holds the screen observed *after* taking ``actions[i]``, the
    ``rewards[i]`` received for it, and whether that screen was terminal
    (``terminals[i]``). A "state" is ``history_length`` consecutive screens.
    Once ``memory_size`` transitions have been added, the oldest slot is
    overwritten.

    Args:
        data_format: ``'NHWC'`` makes :meth:`sample` return states shaped
            ``(batch, height, width, history)``; any other value returns
            ``(batch, history, height, width)``.
        memory_size: Capacity of the buffer, in transitions.
        screen_height: Height every added screen must have.
        screen_width: Width every added screen must have.
        history_length: Number of consecutive screens stacked into one state.
        batch_size: Number of transitions returned by :meth:`sample`.
        model_dir: Directory used by :meth:`save` and :meth:`load`.
    """

    # Attribute names written by save() and read back by load(), in order.
    _PERSISTED_ARRAYS = (
        'actions', 'rewards', 'screens', 'terminals', 'prestates', 'poststates',
    )

    def __init__(
        self,
        data_format,
        memory_size,
        screen_height,
        screen_width,
        history_length,
        batch_size,
        model_dir,
    ):
        self.model_dir = model_dir
        self.data_format = data_format
        self.memory_size = memory_size
        self.actions = np.empty(self.memory_size, dtype=np.uint8)
        # np.int_ is the concrete dtype that the deprecated abstract
        # ``np.integer`` resolved to, so the storage layout is unchanged.
        self.rewards = np.empty(self.memory_size, dtype=np.int_)
        self.screens = np.empty(
            (self.memory_size, screen_height, screen_width), dtype=np.float16)
        # ``np.bool`` was removed in NumPy 1.24; ``np.bool_`` is the real dtype.
        self.terminals = np.empty(self.memory_size, dtype=np.bool_)
        self.history_length = history_length
        self.dims = (screen_height, screen_width)
        self.batch_size = batch_size
        self.count = 0    # number of filled slots, never above memory_size
        self.current = 0  # slot the next add() writes to
        # Pre-allocate prestates and poststates for the minibatch; sample()
        # refills and returns these same buffers on every call.
        state_shape = (self.batch_size, self.history_length) + self.dims
        self.prestates = np.empty(state_shape, dtype=np.float16)
        self.poststates = np.empty(state_shape, dtype=np.float16)

    def add(self, screen, reward, action, terminal):
        """Store one transition in the next slot, overwriting the oldest when full.

        Args:
            screen: Post-state frame, i.e. the screen observed after ``action``
                was taken and ``reward`` received. Its shape must equal
                ``self.dims``.
            reward: Reward for the transition (stored as ``np.int_``).
            action: Action index (stored as ``np.uint8``).
            terminal: Whether ``screen`` ends an episode.
        """
        assert screen.shape == self.dims
        # NB! screen is post-state, after action and reward
        self.actions[self.current] = action
        self.rewards[self.current] = reward
        self.screens[self.current, ...] = screen
        self.terminals[self.current] = terminal
        self.count = max(self.count, self.current + 1)
        self.current = (self.current + 1) % self.memory_size

    def getState(self, index):
        """Return the ``history_length`` screens ending at slot ``index``.

        ``index`` is reduced modulo ``self.count``, so negative values count
        back from the end of the filled region. The result has shape
        ``(history_length, screen_height, screen_width)``.

        Raises:
            AssertionError: if the memory is empty.
        """
        assert self.count > 0, "replay memory is empty, use at least --random_steps 1"
        index = index % self.count
        first = index - (self.history_length - 1)
        if first >= 0:
            # The whole window is contiguous, so a basic slice (a view) is enough.
            return self.screens[first:index + 1, ...]
        # The window starts before slot 0: wrap the missing slots around to the
        # end of the filled region using (slower, copying) list indexing.
        indexes = [(index - i) % self.count for i in reversed(range(self.history_length))]
        return self.screens[indexes, ...]

    def _draw_valid_index(self):
        """Rejection-sample one slot whose pre/post states are usable for training."""
        while True:
            # Sample one index, ignoring states that wrap over slot 0.
            index = random.randint(self.history_length, self.count - 1)
            # Reject windows that straddle the write pointer: they would mix
            # the newest and the oldest (about to be overwritten) data.
            if index >= self.current and index - self.history_length < self.current:
                continue
            # Reject windows that cross an episode end. Only the pre-state
            # slots are checked, so the post-state (last screen) may be terminal.
            if self.terminals[(index - self.history_length):index].any():
                continue
            return index

    def sample(self):
        """Draw ``batch_size`` transitions uniformly (with replacement) from valid slots.

        Returns:
            Tuple ``(prestates, actions, rewards, poststates, terminals)``.
            The state arrays are shaped ``(batch, history, H, W)``, or
            ``(batch, H, W, history)`` when ``data_format == 'NHWC'``. They
            are the pre-allocated buffers (or transposed views of them), so
            the next call to ``sample`` overwrites them.

        Raises:
            AssertionError: unless more than ``history_length`` transitions
                are stored (needed for a pre-state, a post-state and history).
        """
        # memory must include poststate, prestate and history
        assert self.count > self.history_length
        indexes = []
        while len(indexes) < self.batch_size:
            index = self._draw_valid_index()
            row = len(indexes)
            # NB! having index first is fastest in C-order matrices
            self.prestates[row, ...] = self.getState(index - 1)
            self.poststates[row, ...] = self.getState(index)
            indexes.append(index)

        actions = self.actions[indexes]
        rewards = self.rewards[indexes]
        terminals = self.terminals[indexes]

        if self.data_format == 'NHWC':
            return np.transpose(self.prestates, (0, 2, 3, 1)), actions, \
                rewards, np.transpose(self.poststates, (0, 2, 3, 1)), terminals
        return self.prestates, actions, rewards, self.poststates, terminals

    def save(self):
        """Write every persisted array to ``model_dir`` via ``save_npy``, one file per name."""
        for name in self._PERSISTED_ARRAYS:
            save_npy(getattr(self, name), os.path.join(self.model_dir, name))

    def load(self):
        """Restore the arrays written by :meth:`save` from ``model_dir``.

        The data is copied into this instance's existing buffers, which keeps
        their dtypes. The saved arrays must therefore have the same shapes as
        this instance's; incompatible shapes raise ``ValueError`` from NumPy.
        """
        # NOTE(review): count and current are not persisted by save(), so after
        # load() they keep this instance's values (0 for a fresh one) and
        # sample()/getState() still treat the memory as empty. Confirm whether
        # callers restore them separately.
        for name in self._PERSISTED_ARRAYS:
            getattr(self, name)[...] = load_npy(os.path.join(self.model_dir, name))