本文代码参考莫烦 Python 的 Q-learning 示例,并用 MATLAB 重新实现。

main.m

% Tabular Q-learning demo on a one-dimensional track.
% The agent starts in state 1 and moves 'left' or 'right'; the episode ends
% when it reaches state N_STATE. After each move the Q-table entry for the
% (state, action) pair is nudged toward the one-step bootstrapped target.
clear
clc
N_STATE = 6;                 % number of states; state N_STATE is terminal
                             % (update_env draws a fixed 6-cell track, so keep
                             % this at 6 unless that function is changed too)
ACTIONS = {'left','right'};  % action names, also the Q-table column names
EPSILON = 0.9;               % probability of acting greedily
ALPHA = 0.1;                 % learning rate
GAMMA = 0.9;                 % discount factor
MAX_EPISODES = 13;           % number of training episodes

q_table = build_q_table(N_STATE, ACTIONS);
for episode = 1 : MAX_EPISODES
    step_counter = 0;        % moves taken so far in this episode
    S = 1;
    is_terminal = false;
    update_env(S, episode, step_counter);

    while ~is_terminal
        A = choose_action(EPSILON, S, q_table, ACTIONS);   % action name
        [S_, R] = get_env_feedback(S, A, N_STATE);         % next state, reward

        % Column of the chosen action in the Q-table.
        a_idx = find(strcmp(ACTIONS, A), 1);

        q_predict = q_table{S, a_idx};                     % current estimate
        if S_ ~= N_STATE
            % Non-terminal: bootstrap from the best action in the next state.
            q_target = R + GAMMA * max(q_table{S_, :});
        else
            % Terminal: no future value to bootstrap from.
            q_target = R;
            is_terminal = true;
        end
        q_table{S, a_idx} = q_predict + ALPHA * (q_target - q_predict);

        S = S_;
        % Count the move before reporting so total_steps equals the number
        % of actions actually taken in the episode.
        step_counter = step_counter + 1;
        update_env(S, episode, step_counter);
    end
end

以下为主程序 main.m 调用的各个功能函数:

get_env_feedback

function [S_, R] = get_env_feedback(S, A, N_STATE)
% GET_ENV_FEEDBACK  Apply one action to the 1-D track environment.
%   [S_, R] = GET_ENV_FEEDBACK(S, A, N_STATE)
%
%   Inputs:
%     S        current state index (1 is the leftmost cell)
%     A        action name; 'right' moves right, anything else moves left
%     N_STATE  total number of states; state N_STATE is the terminal state
%
%   Outputs:
%     S_       next state index
%     R        reward: 1 when the move enters the terminal state, else 0
    if strcmp(A, 'right')
        if S == N_STATE - 1
            % Stepping right from the second-to-last cell reaches the goal.
            S_ = N_STATE;
            R = 1;
        else
            S_ = S + 1;
            R = 0;
        end
    else
        % Moving left never yields a reward; the leftmost cell is a wall.
        R = 0;
        S_ = max(S - 1, 1);
    end
end

choose_action

function action_name = choose_action(EPSILON, state, table, ACTIONS)
% CHOOSE_ACTION  Epsilon-greedy action selection from a Q-table.
%   action_name = CHOOSE_ACTION(EPSILON, state, table, ACTIONS)
%
%   Inputs:
%     EPSILON  probability of taking the greedy action
%     state    row index into the Q-table
%     table    Q-table (MATLAB table) with one column per action
%     ACTIONS  cell array of action names, in Q-table column order
%
%   Output:
%     action_name  the chosen element of ACTIONS
%
%   A uniformly random action is returned with probability 1 - EPSILON, or
%   whenever every Q-value of the state is zero (nothing learned yet).
%   Otherwise the action with the largest Q-value is returned; ties go to
%   the first such column.
    q_row = table{state, :};   % Q-values of all actions in this state

    if (rand > EPSILON) || ~any(q_row)
        % randi is base MATLAB, so no toolbox is needed for the random pick.
        action_name = ACTIONS{randi(numel(ACTIONS))};
    else
        [~, index] = max(q_row);
        action_name = ACTIONS{index};
    end
end

build_q_table

function q_table = build_q_table(n_states, actions)
% BUILD_Q_TABLE  Create a zero-initialised Q-table.
%   q_table = BUILD_Q_TABLE(n_states, actions)
%
%   Inputs:
%     n_states  number of rows (one per state)
%     actions   cell array of action names, used as the column names
%
%   Output:
%     q_table   n_states-by-numel(actions) table of doubles, all zero
    n_actions = numel(actions);
    % One 'double' column per action, so any number of actions is accepted.
    varTypes = repmat({'double'}, 1, n_actions);
    q_table = table('Size', [n_states, n_actions], ...
                    'VariableTypes', varTypes, ...
                    'VariableNames', actions);
end

update_env

function update_env(S, episode, step_counter)
% UPDATE_ENV  Show the agent on the track, or report a finished episode.
%   UPDATE_ENV(S, episode, step_counter)
%
%   When S is 6 (the cell marked 'T'), a summary line with the episode
%   number and step count is printed. For any other S the track is drawn
%   with an 'o' at position S, after a 0.1 s pause.
    track = '-----T';
    goal_state = 6;

    if S ~= goal_state
        track(S) = 'o';
        pause(0.1);      % short delay before the frame is shown
        disp(track);
        return
    end

    fprintf('Episode %d: total_steps = %d\n', episode, step_counter);
end
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐