
QBrain

二胡艺
Published on 2017/04/18 17:21


import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Activation, Dense, Flatten
from collections import deque
class QBrain:
    def __init__(self, num_act=18, capacity=10000, episode=1000, batch_size=32, epsilon=0.3, gamma=0.7):
        self.ACTION = num_act                          # size of the action set
        self.CAPACITY_MEM = capacity                   # replay memory capacity
        self.EPISODE = episode
        self.batch_size = batch_size
        self.epsilon = epsilon                         # exploration rate for epsilon-greedy selection
        self.GAMMA = gamma                             # discount factor
        self.replayMem = deque(maxlen=self.CAPACITY_MEM)
        self.OBSERVE = 10                              # steps to collect before training starts
        self.SKIP_FRAME = 4                            # train once every SKIP_FRAME steps
        self.time_step = 0
    def createQNet(self):
        # three convolution + pooling stages followed by two fully connected layers;
        # the output layer is linear so that Q-values are not clipped at zero
        self.model = Sequential()
        self.model.add(Conv2D(32, (5, 5), strides=(1, 1), input_shape=(80, 80, 3), data_format='channels_last'))
        self.model.add(Activation('relu'))
        self.model.add(MaxPool2D((2, 2)))
        
        self.model.add(Conv2D(64, (5, 5), strides=(1, 1)))
        self.model.add(Activation('relu'))
        self.model.add(MaxPool2D((2, 2)))
        
        self.model.add(Conv2D(128, (5, 5), strides=(1, 1)))
        self.model.add(Activation('relu'))
        self.model.add(MaxPool2D((2, 2)))
        
        self.model.add(Flatten())
        self.model.add(Dense(1000, activation='relu'))
        self.model.add(Dense(1000, activation='relu'))
        self.model.add(Dense(self.ACTION))             # one Q-value per action
        
        self.model.compile(loss='mse', optimizer='sgd')
                
    def trainNet(self):
        # sample a random minibatch of transitions from the replay memory
        replayItem = [self.replayMem[np.random.randint(len(self.replayMem))] for _ in range(self.batch_size)]
        minibatch_state = np.array([item[0] for item in replayItem])
        minibatch_action = np.array([item[1] for item in replayItem])
        minibatch_reward = np.array([item[2] for item in replayItem])
        minibatch_state_next = np.array([item[3] for item in replayItem])
        minibatch_terminal = np.array([item[4] for item in replayItem])
        
        # one-step Q-learning targets: only the entry of the action actually taken
        # is moved towards reward (+ discounted max Q of the next state)
        q_target = self.model.predict(minibatch_state, batch_size=self.batch_size)
        q_next = self.model.predict(minibatch_state_next, batch_size=self.batch_size)
        for i in range(self.batch_size):
            a = np.argmax(minibatch_action[i])
            if minibatch_terminal[i]:
                q_target[i, a] = minibatch_reward[i]
            else:
                q_target[i, a] = minibatch_reward[i] + self.GAMMA * np.max(q_next[i])
        loss = self.model.train_on_batch(minibatch_state, q_target)
        print(loss)
    def getAction(self):
        # epsilon-greedy selection: explore with probability epsilon, otherwise
        # take the action with the highest predicted Q-value for the current state
        self.action = np.zeros(self.ACTION)
        if np.random.random() < self.epsilon:
            self.action[np.random.randint(self.ACTION)] = 1
        else:
            Q_values = self.model.predict(self.currState[np.newaxis, ...])
            self.action[np.argmax(Q_values[0])] = 1
        return self.action
    def setSequenceState(self, nextState, reward, terminate):
        # store the transition and train every SKIP_FRAME steps once enough
        # observations have been collected
        self.replayMem.append([self.currState, self.action, reward, nextState, terminate])
        if self.time_step > self.OBSERVE and self.time_step % self.SKIP_FRAME == 0:
            self.trainNet()
        self.time_step += 1
        self.currState = nextState
        
    def initSequenceState(self, state):
        # begin an episode with a random action and remember the first frame
        self.action = np.zeros(self.ACTION)
        self.action[np.random.randint(self.ACTION)] = 1
        self.currState = state
        return self.action
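The regression target that trainNet builds is the standard one-step Q-learning target: at a terminal transition it is the observed reward alone, otherwise the reward plus the discounted maximum predicted Q-value of the next state. A minimal standalone sketch of that rule (plain NumPy; the function name q_target is chosen here just for illustration):

import numpy as np

def q_target(reward, q_next, terminal, gamma=0.7):
    # terminal transition: the return is just the observed reward
    if terminal:
        return reward
    # otherwise bootstrap on the best predicted action value of the next state
    return reward + gamma * np.max(q_next)
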
%matplotlib inline
from matplotlib import pyplot as plt
from ale_python_interface import ALEInterface
import cv2
 
def imgpreprocess(img, size=(80, 80)):
    # downscale the raw ALE screen to the network input size
    return cv2.resize(img, size)

def normalize(img):
    # zero-centre the frame and scale it to roughly [-1, 1]
    im = img.astype(np.float32)
    return (im - im.mean()) / 255
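
# Quick sanity check of the two helpers above (illustrative only: a random dummy
# frame stands in for a real ALE screen, which is typically 210x160x3 for Atari).
dummy = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
print(normalize(imgpreprocess(dummy)).shape)    # (80, 80, 3)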

agent = QBrain()
agent.createQNet()
game_path = b'F:/github/Arcade-GAME2/Breakout.a26'
ale = ALEInterface()
ale.setInt(b'random_seed',123)
ale.setBool(b'display_screen',True)
ale.setInt(b'frame_skip',4)
ale.loadROM(game_path)
legal_actions = ale.getLegalActionSet()
w,h = ale.getScreenDims()
screenData = np.empty((h,w,3),dtype=np.uint8)
ale.getScreenRGB(screenData)
init_state = imgpreprocess(screenData)
init_state = normalize(init_state)
#plt.imshow(init_state)
action = agent.initSequenceState(init_state)
total_reward = 0
while not ale.game_over():
    reward = ale.act(legal_actions[action.argmax()])
    ale.getScreenRGB(screenData)
    next_state = normalize(imgpreprocess(screenData))
    # clip the reward to {0, 1} before storing the transition
    clipped_reward = 1 if reward > 0 else 0
    agent.setSequenceState(next_state, clipped_reward, ale.game_over())
    action = agent.getAction()
    total_reward += reward
    print("action:", action.argmax(), "   total_reward:", total_reward)

© Copyright belongs to the author
