General Documentation

This page hosts the general documentation of the ActionRNNs.jl library. This includes all research code used in this project.

Cells

ActionRNNs.AbstractActionRNNType
AbstractActionRNN

An abstract type for cells that take the current hidden state and a tuple of observations and actions and return the next hidden state.

source

Basic Cells

ActionRNNs.AARNNFunction
AARNN(in::Integer, actions::Integer, out::Integer, σ = tanh)

Like an RNN cell, except it takes a tuple (action, observation) as input. The action is passed through get_waa and the result is added to the usual update.

The update is as follows: σ.(Wi*o .+ get_waa(Wa, a) .+ Wh*h .+ b)
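As a rough sketch of what this computes (not the library's implementation), with an integer action and assuming get_waa(Wa, a) reduces to selecting the a-th column of Wa:

```julia
# Illustrative dimensions only: in = 4, actions = 3, out = 6.
o  = rand(Float32, 4)        # observation
h  = zeros(Float32, 6)       # previous hidden state
a  = 2                       # integer action
Wi = randn(Float32, 6, 4)    # observation weights
Wa = randn(Float32, 6, 3)    # per-action additive weights
Wh = randn(Float32, 6, 6)    # recurrent weights
b  = zeros(Float32, 6)

# matches σ.(Wi*o .+ get_waa(Wa, a) .+ Wh*h .+ b) with σ = tanh
h′ = tanh.(Wi*o .+ Wa[:, a] .+ Wh*h .+ b)
```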

source
ActionRNNs.AAGRUFunction
AAGRU(in, actions, out)

Additive Action Gated Recurrent Unit layer. Behaves like an AARNN but uses a GRU internal structure

source
ActionRNNs.MARNNFunction
MARNN(in::Integer, actions::Integer, out::Integer, σ = tanh)

This cell incorporates the action as a multiplicative operation. We use contract_WA and get_waa to handle this.

The update is as follows:

new_h = σ.(contract_WA(m.Wx, a, o) .+ contract_WA(m.Wh, a, h) .+ get_waa(m.b, a))
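A rough sketch of the multiplicative update for an integer action, assuming contract_WA(W, a, x) reduces to slicing the action dimension of W (the weights are assumed to be nactions × out × in, see contract_WA below) and get_waa(b, a) selects a column of b:

```julia
# Illustrative dimensions only: in = 4, actions = 3, out = 6.
o  = rand(Float32, 4)           # observation
h  = zeros(Float32, 6)          # previous hidden state
a  = 1                          # integer action
Wx = randn(Float32, 3, 6, 4)    # action-conditioned input weights
Wh = randn(Float32, 3, 6, 6)    # action-conditioned recurrent weights
b  = zeros(Float32, 6, 3)       # per-action bias

h′ = tanh.(Wx[a, :, :]*o .+ Wh[a, :, :]*h .+ b[:, a])
```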
source
ActionRNNs.MAGRUFunction
MAGRU(in, actions, out)

Multiplicative Action Gated Recurrent Unit layer. Behaves like an MARNN but uses a GRU internal structure.

source
ActionRNNs.FacMARNNFunction
FacMARNN(in::Integer, actions::Integer, out::Integer, factors, σ = tanh; init_style="ignore")

This cell incorporates the action as a multiplicative operation, but as a factored approximation of the multiplicative version. This cell uses get_waa. Uses CP decomposition.

The update is as follows:

   new_h = m.σ.(W*((Wx*o .+ Wh*h) .* get_waa(Wa, a)) .+ get_waa(m.b, a))
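A rough sketch of the factored update under the same assumptions as above (integer action, get_waa selecting a column), with a hypothetical factor count k:

```julia
# Illustrative dimensions only: in = 4, actions = 3, out = 6, factors k = 5.
o  = rand(Float32, 4)
h  = zeros(Float32, 6)
a  = 3
k  = 5
Wx = randn(Float32, k, 4)    # input factor loadings
Wh = randn(Float32, k, 6)    # recurrent factor loadings
Wa = randn(Float32, k, 3)    # per-action factor scalings
W  = randn(Float32, 6, k)    # output mixing matrix
b  = zeros(Float32, 6, 3)

h′ = tanh.(W*((Wx*o .+ Wh*h) .* Wa[:, a]) .+ b[:, a])
```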

Three init_styles:

  • standard: using init and initb w/o any keywords
  • ignore: W = init(out, factors, ignore_dims=2)
  • tensor: Decompose W_t = init(actions, out, in+out; ignore_dims=1) to get W_o, W_a, W_hi using TensorToolbox.cp_als.
source
ActionRNNs.FacMAGRUFunction
FacMAGRU(in, actions, out, factors)

Factored Multiplicative Action Gated Recurrent Unit layer. Behaves like an FacMARNN but uses a GRU internal structure.

Three init_styles:

  • standard: using init and initb w/o any keywords
  • ignore: W = init(out, factors, ignore_dims=2)
  • tensor: Decompose W_t = init(actions, out, in+out; ignore_dims=1) to get W_o, W_a, W_hi using TensorToolbox.cp_als.
source
ActionRNNs.FacTucMARNNFunction
FacTucMARNN(in::Integer, actions::Integer, out::Integer, action_factors, out_factors, in_factors, σ = tanh; init_style="ignore")

This cell incorporates the action as a multiplicative operation, but as a factored approximation of the multiplicative version. This cell uses get_waa. Uses Tucker decomposition.

Three init_styles:

  • standard: using init and initb w/o any keywords
  • ignore: Wa = init(action_factors, actions; ignore_dims=2)
source
ActionRNNs.FacTucMAGRUFunction
FacTucMAGRU(in, actions, out, factors)

Factored Multiplicative Action Gated Recurrent Unit layer. Behaves like an FacTucMARNN but uses a GRU internal structure.

source

Combo Cells

ActionRNNs.CaddAAGRUFunction
CaddAAGRU(in, actions, out)

Mixing between two AAGRU cells through a weighting:

```julia
h′ = (w[1]*new_hAA1 + w[2]*new_hAA2) ./ sum(w)
```

source
ActionRNNs.CaddMAGRUFunction
CaddMAGRU(in, actions, out)

Mixing between two MAGRU cells through a weighting:

```julia
h′ = (w[1]*new_hMA1 + w[2]*new_hMA2) ./ sum(w)
```

source
ActionRNNs.CaddElRNNFunction
CaddElRNN(in, actions, out, σ = tanh)

Mixing between AARNN and MARNN through a weighting

h′ = (AA_θ .* AA_h′ .+ MA_θ .* MA_h′) ./ (AA_θ .+ MA_θ)
source
ActionRNNs.CaddElGRUFunction
CaddElGRU(in, actions, out)

Mixing between AAGRU and MAGRU through a weighting

h′ = (AA_θ .* AA_h′ .+ MA_θ .* MA_h′) ./ (AA_θ .+ MA_θ)
source

Mixed Cells

ActionRNNs.MixRNNFunction
MixRNN(in, actions, out, num_experts, σ = tanh)

Mixing between num_experts AARNN cells. Uses the weighting

h′ = sum(θ[i] .* expert_h′[i] for i in 1:length(θ)) ./ sum(θ)
source
ActionRNNs.MixElRNNFunction
MixElRNN(in, actions, out, num_experts, σ = tanh)

Mixing between num_experts AARNN cells. Uses the weighting

h′ = sum(θ[i] .* expert_h′[i] for i in 1:length(θ)) ./ sum(θ)

(here θ[i] is a vector).

source
ActionRNNs.MixGRUFunction
MixGRU(in, actions, out, num_experts)

Mixing between num_experts AAGRU cells. Uses the weighting

h′ = sum(θ[i] .* expert_h′[i] for i in 1:length(θ)) ./ sum(θ)
source
ActionRNNs.MixElGRUFunction
MixElGRU(in, actions, out, num_experts)

Mixing between num_experts AAGRU cells. Uses the weighting

h′ = sum(θ[i] .* expert_h′[i] for i in 1:length(θ)) ./ sum(θ)

(here θ[i] is a vector).

source
ActionRNNs.ActionGatedRNNFunction
ActionGatedRNN(in::Integer, na, internal, out::Integer, σ = tanh)

The most basic recurrent layer; essentially acts as a Dense layer, but with the output fed back into the input each time step.

source

Old/Defunct Cells

ActionRNNs.GAIARNNFunction
GAIARNN(in::Integer, na, internal, out::Integer, σ = tanh)

The most basic recurrent layer; essentially acts as a Dense layer, but with the output fed back into the input each time step.

source

Shared operations for cells

HelpfulKernelFuncs.contract_WAFunction
contract_WA(W, a::Int, x)
contract_WA(W, a::AbstractVector{Int}, x)
contract_WA(W, a::AbstractVector{<:AbstractFloat}, x)
contract_WA(W::CuArray, a::AbstractVector{Int}, x)

This contraction operator takes the weights W, the action (or action vector for batches) a, and the features x. The weight array W is assumed to have shape nactions × out × in.
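A minimal reference sketch of the semantics (the library's kernels are optimized and also provide a CuArray method; contract_WA_ref is a hypothetical name used here for illustration):

```julia
# Hypothetical reference implementation; W is nactions × out × in.
function contract_WA_ref(W::AbstractArray{<:Real,3}, a::Int, x::AbstractVector)
    W[a, :, :] * x
end

# Batched case: one action per column of x (x is in × batch).
function contract_WA_ref(W::AbstractArray{<:Real,3}, a::AbstractVector{Int}, x::AbstractMatrix)
    hcat((W[a[i], :, :] * x[:, i] for i in eachindex(a))...)
end
```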

HelpfulKernelFuncs.get_waaFunction
get_waa(Wa, a)

Different ways of handling getting the action value from a set of weights. This operation can be seen as Wa*a, where Wa is the weight matrix and a is the action representation. It is used by the various cells to incorporate this operation more reliably.
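A minimal sketch of the two most common cases (integer action vs. action-feature vector); get_waa_ref is a hypothetical name used here for illustration:

```julia
# With an integer action, select the corresponding column of Wa;
# with an action-feature vector, this is just the matrix-vector product Wa*a.
get_waa_ref(Wa::AbstractMatrix, a::Int)            = Wa[:, a]
get_waa_ref(Wa::AbstractMatrix, a::AbstractVector) = Wa * a
```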

Other Layers

ActionRNNs.ActionDenseType
ActionDense(in, na, out, σ; init, bias)

Create an action-conditional Dense layer. This layer takes a tuple (action, observation) and applies the dense layer using an additive approach. It can be used with previous actions or current actions.
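A minimal usage sketch; the dimensions are illustrative, and the commented call assumes the (action, observation) tuple convention described above:

```julia
using ActionRNNs

layer = ActionDense(4, 3, 16, tanh)   # in = 4 features, na = 3 actions, out = 16
# y = layer((2, rand(Float32, 4)))    # assumed (action, observation) calling convention
```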

source

Learning Updates

ActionRNNs.QLearningType
QLearning
QLearningMSE(γ)
QLearningSUM(γ)
QLearningHUBER(γ)

Watkins Q-learning with various loss functions.

source

Constructors

ActionRNNs.build_rnn_layerFunction
build_rnn_layer(in, actions, out, parsed, rng)

Build an RNN layer from the parsed config dict. This assumes the "cell" key is in the parsed dict. in, actions, and out are integers. An RNG must be passed in explicitly.

Gets layer constructor from either the ActionRNNs or Flux namespaces.
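A minimal sketch of building a layer from a config dict; the cell name "MAGRU" and the dimensions are illustrative, and additional keys depend on the build type (see the methods below):

```julia
using ActionRNNs, Random

parsed = Dict("cell" => "MAGRU")    # "cell" is the required key
rng = Random.MersenneTwister(1)     # the RNG must be passed explicitly
layer = ActionRNNs.build_rnn_layer(10, 4, 32, parsed, rng)
```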

Build types:

source
build_rnn_layer(::BuildActionRNN, args...; kwargs...)

Standard Additive and Multiplicative cells. No extra parameters.

source
build_rnn_layer(::BuildFactored, args...; kwargs...)

Factored (not Tucker) cells. Extra config options:

  • init_style::String: The style of initialization. Check your cell for possible options.
  • factors::Int: Number of factors in the factorization.
source
build_rnn_layer(::BuildTucFactored, args...; kwargs...)

Tucker-factored cells. Extra config options:

  • in_factors::Int: Number of factors in input matrix
  • action_factors::Int: Number of factors in action matrix
  • out_factors::Int: Number of factors in out matrix
source
build_rnn_layer(::BuildComboCat, args...; kwargs...)

Combo cat AA/MA cells. No extra parameters.

source
build_rnn_layer(::BuildComboAdd, args...; kwargs...)

Combo add AA/MA cells. No extra parameters.

source
build_rnn_layer(::BuildMixed, args...; kwargs...)

Mixed layers. Extra config options:

  • num_experts::Int: number of parallel cells in the mixture.

source
build_rnn_layer(::BuildFlux, args...; kwargs...)

Flux cell. No extra parameters.

source

Agents

Experience Replay Agents

ActionRNNs.AbstractERAgentType
AbstractERAgent

The abstract struct for building experience replay agents.

An example agent:

```julia
mutable struct DRQNAgent{ER, Φ, Π, HS<:AbstractMatrix{Float32}} <: AbstractERAgent
    lu::LearningUpdate
    opt::O
    model::C
    target_network::CT

    build_features::F
    state_list::DataStructures.CircularBuffer{Φ}

    hidden_state_init::Dict{Symbol, HS}

    replay::ER
    update_timer::UpdateTimer
    target_update_timer::UpdateTimer

    batch_size::Int
    τ::Int

    s_t::Φ
    π::Π
    γ::Float32

    action::Int
    am1::Int
    action_prob::Float64

    hs_learnable::Bool
    beg::Bool
    cur_step::Int

    hs_tr_init::Dict{Symbol, HS}
end
```

source

Instantiations

Implementation details

MinimalRLCore.start!Method
    MinimalRLCore.start!(agent::AbstractERAgent, s, rng; kwargs...)

Start the agent for a new episode.

source
MinimalRLCore.step!Method
MinimalRLCore.step!(agent::AbstractERAgent, env_s_tp1, r, terminal, rng; kwargs...)

step! for an experience replay agent.

source
MinimalRLCore.step!Function
MinimalRLCore.step!(agent::AbstractERAgent, env_s_tp1, r, terminal, rng; kwargs...)

step! for an experience replay agent.

source
ActionRNNs.update!Method
update!(agent::AbstractERAgent{<:ControlUpdate}, rng)

Update the parameters of the model.

source
ActionRNNs.update!Method
update!(agent::AbstractERAgent{<:PredictionUpdate}, rng)

Update the parameters of the model.

source
ActionRNNs.update!Function
update!(agent::AbstractERAgent{<:ControlUpdate}, rng)

Update the parameters of the model.

source
update!(agent::AbstractERAgent{<:PredictionUpdate}, rng)

Update the parameters of the model.

source

Online Agents

Tools/Utils

ActionRNNs.make_obs_listFunction
make_obs_list

Makes the obs list and initial state used for recurrent networks in an agent. Uses an init function to define the init tuple.

source
Hidden state manipulation
ActionRNNs.reset!Function
reset!(m, h_init::Dict)
reset!(m::Flux.Recur, h_init)

Reset the hidden state according to the dict h_init, with keys from [`get_hs_symbol_list`](@ref). If the model is a Flux.Recur, just replace the hidden state.

source
Replay buffer
ActionRNNs.CircularBufferType

CircularBuffer

Maintains a buffer of fixed size w/o reallocating and deallocating memory, using a circular queue data structure.

source
ActionRNNs.StateBufferType
StateBuffer(size::Int, state_size)

A circular buffer for states. Typically used for images; can be used for state shapes up to 4d.

source
Base.lengthMethod
length(buffer)

Returns the current amount of data in the circular buffer. If the full flag is true then we return the size of the whole data frame.

source
Base.push!Method
push!(buffer, data)

Adds data to the buffer, where data is an array of collections of the types defined in CircularBuffer.datatypes. Returns the row of the added data.

source
ActionRNNs.get_state_from_experienceFunction
get_state_from_experience

Returns hidden state from experience sampled from an experience replay buffer. This assumes the replay has (:am1, :s, :a, :sp, :r, :t, :beg, hs_symbol...) as columns.

source
ActionRNNs.get_information_from_experienceFunction
get_information_from_experience(agent, exp)

Gets the tuple of required details for the update of the agent. This is dispatched on the type of learning update. You can use the helper abstract types, or dispatch on your specific update.

source
ActionRNNs.get_hs_from_experience!Function
get_hs_from_experience!(model, exp::NamedTuple, hs_dict::Dict, device)
get_hs_from_experience!(model, exp::Vector, hs_dict::Dict, device)

Get the hidden state in the appropriate format from the experience (either a NamedTuple or a vector of NamedTuples).

source
Flux Chain Manipulation
ActionRNNs.find_layers_with_eqFunction
find_layers_with_eq(eq::Function, model)

A function which takes a model and a predicate function and returns the locations where the predicate returns true. This only supports chains composed (nested) at most twice.
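For example, a sketch that finds the recurrent layers of a Flux chain (the predicate and model are illustrative):

```julia
using ActionRNNs, Flux

model = Flux.Chain(Flux.Dense(4, 8), Flux.RNN(8, 8), Flux.Dense(8, 2))
locs = ActionRNNs.find_layers_with_eq(l -> l isa Flux.Recur, model)
```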

source

Policies

ActionRNNs.ϵGreedyDecayType
ϵGreedyDecay{AS}(ϵ_range, decay_period, warmup_steps, action_set::AS)
ϵGreedyDecay(ϵ_range, end_step, num_actions)

This is an acting policy which decays exploration linearly over time. This API will possibly change over time once I figure out a better way to specify decaying epsilon.

Arguments

  • ϵ_range::Tuple{Float64, Float64}: (max epsilon, min epsilon)
  • decay_period::Int: period over which epsilon decays
  • warmup_steps::Int: number of steps before decay starts
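For example, a sketch of a policy decaying ϵ from 1.0 down to 0.1 over the first 10,000 steps with 4 actions (the values are illustrative):

```julia
using ActionRNNs

policy = ActionRNNs.ϵGreedyDecay((1.0, 0.1), 10_000, 4)
```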

source

Feature Constructors

Environments

RingWorld

ActionRNNs.RingWorldType

RingWorld

A ring of states 1, 2, 3, ..., n whose ends wrap around. The observation is 1 in state 1 and 0 in every other state.

chain_length: size (diameter) of the ring
actions: Forward or Backward

source

LinkedChains

ActionRNNs.LinkedChainsV2Type
LinkedChains

termmode:

  • CONT: No termination
  • TERM: Terminate after chain

dynmode:

  • STRAIGHT: high negative reward on wrong actions, but still progress through the chain
  • JUMP: Jump to different chain on wrong action
  • STUCK: Don't progress on wrong action
  • JUMPSTUCK: Get "lost" with wrong actions, still being implemented.
source

TMaze

DirectionalTMaze

ActionRNNs.DirectionalTMazeType
DirectionalTMaze

Similar to ActionRNNs.TMaze but with a directional component overlaid on top. This also changes the observation structure: the agent must know which direction it is facing to get information about which goal is the good goal.

source

Masked Grid World

ActionRNNs.MaskedGridWorldType
MaskedGridWorld

This grid world gives observations at a set of randomly placed anchor states, which may be aliased (or not, depending on obs_strategy). This environment also has the pacman_wrapping flag, which makes the edges wrap around.

  • width::Int: width of gw
  • height::Int: height of gw
  • anchors::Int: number of anchors (Int), or list of anchor states
  • goals_or_rews: number of goals, list of goals, or list of rewards.
  • obs_strategy: which observations are returned: :seperate, :full, or :aliased
  • pacman_wrapping::Bool: whether the walls are invisible and wrap around
source

Lunar Lander

FluxUtils Stuff

ActionRNNs.ExpUtils.FluxUtils.get_optimizerFunction
get_optimizer

Return the Flux optimizer given a config dictionary. The optimizer name is found at key "opt". The parameters also change based on the optimizer.

  • OneParamInit: eta::Float
  • TwoParamInit: eta::Float, rho::Float
  • AdamParamInit: eta::Float, beta::Vector or (beta_m::Int, beta_v::Int)
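For example, a sketch of a config selecting an Adam-style optimizer; the key names other than "opt" are assumptions for illustration:

```julia
using ActionRNNs

config = Dict("opt" => "ADAM", "eta" => 0.001)   # hypothetical keys beyond "opt"
opt = ActionRNNs.ExpUtils.FluxUtils.get_optimizer(config)
```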
source

Misc