Skip to content

DictPolicy and special Q-learning based on key-value storage #459

@NeroBlackstone

Description

@NeroBlackstone

If we have a discrete space, discrete action, generative MDP.
And the state space and action space are hard to enumerate. But we still want to use a traditional tabular RL algorithm to solve it.
So, I implemented a DictPolicy, which is used to store state-action pair values. (Of course, users need to add Base.isequal() and Base.hash() methods for their state and action types.)

DictPolicy.jl :

"""
    DictPolicy{P,T}

A tabular policy backed by a dictionary instead of a dense value matrix,
for problems whose state/action spaces are hard to enumerate up front.
`value_dict` maps `(state, action)` tuples to Q-values; `action(p, s)`
picks the greedy action. State and action types must implement
`Base.isequal` and `Base.hash` so they can serve as dictionary keys.
"""
struct DictPolicy{P<:Union{POMDP,MDP}, T<:AbstractDict{Tuple,Float64}} <: Policy
    mdp::P          # the problem this policy acts in
    value_dict::T   # (state, action) tuple => Q-value
end

# Returns the action that the policy deems best for the current state
"""
    action(p::DictPolicy, s)

Return the greedy action at state `s` according to the stored Q-values.

State-action pairs not present in `p.value_dict` are treated as having
value 0.0, so never-visited states are handled gracefully. Ties resolve
to the first such action produced by `actions(p.mdp, s)`.
"""
function action(p::DictPolicy, s)
    available_actions = actions(p.mdp, s) # was `actions(mdp, s)`: `mdp` is undefined in this scope
    max_action = nothing
    max_action_value = -Inf               # was 0: a 0 floor silently ignored all-negative Q-values
    for a in available_actions
        # Default missing entries to 0.0 instead of mutating the Q-table
        # inside a query function (the original inserted zeros here).
        action_value = get(p.value_dict, (s, a), 0.0)
        if action_value > max_action_value
            max_action = a
            max_action_value = action_value
        end
    end
    if max_action === nothing
        # Defensive fallback; `first` works for any iterable action container,
        # unlike `available_actions[1]` which requires indexability.
        max_action = first(available_actions)
    end
    return max_action
end

# returns the values of each action at state s in a dict
"""
    actionvalues(p::DictPolicy, s) -> Dict

Return a `Dict` mapping each action available at state `s` to its stored
Q-value, with 0.0 for state-action pairs not yet in `p.value_dict`.
"""
function actionvalues(p::DictPolicy, s)::Dict
    action_dict = Dict{Any,Float64}() # value type is known; keys are user action types
    for a in actions(p.mdp, s)        # was `actions(mdp, s)`: `mdp` is undefined in this scope
        # was bare `value_dict[(s,a)]`: missing the `p.` qualifier
        action_dict[a] = get(p.value_dict, (s, a), 0.0)
    end
    return action_dict
end

"""
    Base.show(io, ::MIME"text/plain", p::DictPolicy)

Multi-line REPL display: print a summary header, then render the policy
with `showpolicy` in a context one row shorter so the header fits.
"""
function Base.show(io::IO, mime::MIME"text/plain", p::DictPolicy{M}) where M <: MDP
    summary(io, p)
    println(io, ':')
    ds = get(io, :displaysize, displaysize(io))
    ioc = IOContext(io, :displaysize=>(first(ds)-1, last(ds)))
    # Bug fix: the original passed `io` here, so the shrunken `ioc`
    # context was constructed but never actually used.
    showpolicy(ioc, mime, p.mdp, p)
end

Then we have a special Q-learning implementation based on key-value storage, so we don't need to enumerate the state and action spaces in the MDP definition. (Okay, most of the code is copied from TabularTDLearning.jl, but the Q-value storage and reads are changed.)

dict_q_learning.jl :

"""
    QLearningSolver

Q-learning solver whose Q-values live in a `Dict{Tuple,Float64}` keyed by
`(state, action)` tuples, so the state/action spaces never need to be
enumerated. Constructed via `@with_kw` keyword defaults; `Q_vals` may be
supplied to warm-start learning, otherwise an empty table is created.
"""
@with_kw mutable struct QLearningSolver{E<:ExplorationPolicy} <: Solver
   n_episodes::Int64 = 100      # number of training episodes
   max_episode_length::Int64 = 100  # step cap per episode
   learning_rate::Float64 = 0.001   # TD update step size (alpha)
   exploration_policy::E            # e.g. epsilon-greedy; consulted each step
   Q_vals::Union{Nothing, Dict{Tuple,Float64}} = nothing  # optional warm-start Q-table
   eval_every::Int64 = 10           # evaluate the greedy policy every N episodes
   n_eval_traj::Int64 = 20          # rollouts averaged per evaluation
   rng::AbstractRNG = Random.GLOBAL_RNG
   verbose::Bool = true             # print evaluation returns during training
end

"""
    solve(solver::QLearningSolver, mdp::MDP) -> DictPolicy

Run tabular Q-learning with a dictionary-backed Q-table and return the
resulting greedy `DictPolicy`. Q-values for unseen `(state, action)`
pairs default to 0.0. Every `solver.eval_every` episodes the current
greedy policy is evaluated with `solver.n_eval_traj` rollouts and the
mean return is printed when `solver.verbose` is set.
"""
function solve(solver::QLearningSolver, mdp::MDP)
    rng = solver.rng
    # Reuse a warm-start Q-table if one was supplied, else start empty.
    Q = solver.Q_vals === nothing ? Dict{Tuple,Float64}() : solver.Q_vals
    exploration_policy = solver.exploration_policy
    sim = RolloutSimulator(rng=rng, max_steps=solver.max_episode_length)

    on_policy = DictPolicy(mdp, Q)
    k = 0  # global step counter fed to the exploration policy
    for i = 1:solver.n_episodes
        s = rand(rng, initialstate(mdp))
        t = 0
        while !isterminal(mdp, s) && t < solver.max_episode_length
            a = action(exploration_policy, on_policy, k, s)
            k += 1
            sp, r = @gen(:sp, :r)(mdp, s, a, rng)
            # max_a' Q(sp, a') over the actions available at sp, with 0.0 for
            # unseen pairs. Replaces the original O(|Q|) scan over every key
            # in the table, which also clamped the max at 0 and so silently
            # ignored negative Q-values. (The scan's loop variable also
            # shadowed the step counter `k`.) `init` covers terminal states
            # with no available actions.
            max_sp_prediction = maximum(
                (get(Q, (sp, ap), 0.0) for ap in actions(mdp, sp)); init=0.0)
            # get! inserts the 0.0 default and returns it in one lookup.
            current_s_prediction = get!(Q, (s, a), 0.0)
            Q[(s, a)] += solver.learning_rate *
                (r + discount(mdp) * max_sp_prediction - current_s_prediction)
            s = sp
            t += 1
        end
        if i % solver.eval_every == 0
            r_tot = 0.0
            for _ in 1:solver.n_eval_traj
                r_tot += simulate(sim, mdp, on_policy, rand(rng, initialstate(mdp)))
            end
            solver.verbose && println("On Iteration $i, Returns: $(r_tot/solver.n_eval_traj)")
        end
    end
    return on_policy
end

What's your point of view? Do you have any advice?
Thank you for taking the time to read my issue.
If you think it's meaningful, I can open a PR and add some tests.
It's okay if you think it's not meaningful or general enough. I just finished it to solve my own MDP.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions