强化学习--Sarsa
系列文章目录强化学习提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档文章目录系列文章目录前言一、强化学习是什么?二、核心算法(免模型学习) Sarsa1.学习心得前言强化学习(Reinforcement Learning, RL),又称再励学习、评价学习或增强学习,是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现
系列文章目录
强化学习
提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
强化学习(Reinforcement Learning, RL),又称再励学习、评价学习或增强学习,是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题 。
一、强化学习是什么?
强化学习是智能体(Agent)以“试错”的方式进行学习,通过与环境进行交互获得的奖赏指导行为,目标是使智能体获得最大的奖赏,强化学习不同于连接主义学习中的监督学习,主要表现在强化信号上,强化学习中由环境提供的强化信号是对产生动作的好坏作一种评价(通常为标量信号),而不是告诉强化学习系统RLS(reinforcement learning system)如何去产生正确的动作。由于外部环境提供的信息很少,RLS必须靠自身的经历进行学习。通过这种方式,RLS在行动-评价的环境中获得知识,改进行动方案以适应环境。
理解:强化学习其实就是和人一样,一开始是什么都不懂的,所谓吃一堑长一智,他像一个新生的孩子,它在不断的试错过程中慢慢知道了做什么有奖励,做什么对得到奖励会有一定的价值,做什么会被打。在这个过程中不会像监督学习一样有个师傅带你,完全需要自己去摸索,就像修仙宗门一样,有背景的宗门弟子是继承掌门之位(监督),创立宗门的人是开山立派(强化),必须一步一个脚印去不断成长。
其实强化学习吸引我的就是因为它主要使用在游戏上,例如:
在 Flappy bird 这个游戏中,我们需要简单的点击操作来控制小鸟,躲过各种水管,飞的越远越好,因为飞的越远就能获得更高的积分奖励。
机器有一个玩家小鸟——Agent
需要控制小鸟飞的更远——目标
整个游戏过程中需要躲避各种水管——环境
躲避水管的方法是让小鸟用力飞一下——行动
飞的越远,就会获得越多的积分——奖励
二、核心算法(免模型学习) Sarsa
1.学习心得
总的理解:sarsa和QLearning很像,唯一不同就是off-policy与on-policy的不同,QLearning很勇敢最终会找到一个最近的路,而Sarsa会选择绕过危险,达到目的即可。也就是超级玛丽之类的游戏,有的人喜欢去破时间的记录,有的人却为了稳稳地过关就好。
不一样的部分:(e_greedy贪婪在强化学习中,这两个算法中的取值0.9,意思就是90%贪婪,剩下的随机探索,但是e_greedy越小效果会越好,)
class RL(object):
def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
self.actions = action_space # a list
self.lr = learning_rate
self.gamma = reward_decay
self.epsilon = e_greedy
self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
def check_state_exist(self, state):
if state not in self.q_table.index:
# append new state to q table
self.q_table = self.q_table.append(
pd.Series(
[0]*len(self.actions),
index=self.q_table.columns,
name=state,
)
)
def choose_action(self, observation):
self.check_state_exist(observation)
# action selection
if np.random.rand() < self.epsilon:
# choose best action
state_action = self.q_table.loc[observation, :]
# some actions may have the same value, randomly choose on in these actions
action = np.random.choice(state_action[state_action == np.max(state_action)].index)
else:
# choose random action
action = np.random.choice(self.actions)
return action
def learn(self, *args):
pass
# off-policy
class QLearningTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
#当前位置,当前动作,当前价值,下一个位置
def learn(self, s, a, r, s_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, :].max() # next state is not terminal
else:
q_target = r # next state is terminal
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update
# on-policy
class SarsaTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
#当前位置,当前动作,当前价值,下一个位置,下一个动作
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # next state is not terminal
else:
q_target = r # next state is terminal
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update
单步更新和回合更新的区别:
- 单步更新是在得到奖励时认为奖励之前那一步的关系很大,其他没关联。
- 回合更新是在得到奖励时认为奖励之前所有的步骤都是有关联的。
Sarsa是单步更新法,提速方法Lambda:
lambda(衰减率回合制更新法),可以理解成目标离agent的距离衰减情况,衰减的意思就像我们人站在目标的位置看agent,越近看的越清楚(衰减小),越远就越模糊(衰减大) - 当Sarsa(0)时,为单步更新,因为没得衰减,我只认为最近的一步有关系
-
- 当Sarsa(lambda)时,衰减情况,可以调节的参数
-
- 当Sarsa(1)时,为回合更新
# on-policy
class SarsaLambdaTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
self.lambda_=trace_decay
self.eligibility_trace=self.q_table.copy()
def check_state_exist(self, state):
if state not in self.q_table.index:
# append new state to q table
self.q_table = self.q_table.append(
pd.Series(
[0]*len(self.actions),
index=self.q_table.columns,
name=state,
)
)
self.eligibility_trace=self.eligibility_trace.append(to_be_append)
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # next state is not terminal
else:
q_target = r # next state is terminal
error=q_target - q_predict
# 1 无上限
# self.eligibility_trace.loc[s,a]+=1
# 2 上限为1
self.eligibility_trace.loc[s,:]*=0
self.eligibility_trace.loc[s,a]=1
# 更新方式
self.q_table.loc[s, a] += self.lr * error*self.eligibility_trace # update
# 衰变方式
self.eligibility_trace*=self.gamma*self.lambda
更多推荐
所有评论(0)