gridworld.models package

Submodules

gridworld.models.action module

class gridworld.models.action.Action[source]

Bases: enum.Enum

Models all possible actions.

down = 1
left = 3
right = 2
up = 0
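
The members compare by identity and expose their numeric codes via .value. Below is a minimal illustrative sketch; the mapping from each action to a grid offset is an assumption for demonstration only, not part of the documented API:

    from gridworld.models.action import Action

    # Illustrative only: an assumed (row, column) offset per action.
    # The real transition logic lives in the State and World classes.
    ASSUMED_OFFSETS = {
        Action.up: (-1, 0),
        Action.down: (1, 0),
        Action.left: (0, -1),
        Action.right: (0, 1),
    }

    for action in Action:
        print(action.name, action.value, ASSUMED_OFFSETS[action])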

gridworld.models.agent module

class gridworld.models.agent.Agent[source]

Bases: object

Models a Markov Decision Process-based agent.

Parameters
  • policy (numpy.ndarray) –

  • state_value_function (numpy.ndarray) –

  • q_function (dict) –

  • environment (gridworld.models.World) –

evaluate_policy(policy: numpy.ndarray) → numpy.ndarray[source]

Evaluates a policy (Q value for each cell).

Parameters

policy (numpy.ndarray) – Policy to be evaluated using the agent’s estimated Q function.

Returns

state_value_function

Return type

numpy.ndarray
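
The docstring does not spell out the update rule, so the sketch below is a generic, self-contained illustration in plain NumPy rather than the package’s implementation. It assumes a deterministic policy stored as one action index per state and a Q table laid out as (states x actions):

    import numpy as np

    def evaluate_policy_sketch(policy: np.ndarray, q_table: np.ndarray) -> np.ndarray:
        """Illustrative policy evaluation: for a deterministic policy the
        state value is the Q value of the action the policy selects."""
        n_states = policy.shape[0]
        state_value_function = np.zeros(n_states)
        for s in range(n_states):
            state_value_function[s] = q_table[s, int(policy[s])]
        return state_value_function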

improve_policy(value_function, gamma) → numpy.ndarray[source]

Computes a new policy for the given values.

Parameters
  • value_function (numpy.ndarray) – Maximum Q values obtained with the current policy.

  • gamma (float) – Discount factor

Returns

new_policy

Return type

numpy.ndarray
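
improve_policy is the greedy step of policy iteration: for each state, pick the action with the highest expected one-step return under the current values. The sketch below is generic; the transitions structure is an assumption for illustration, whereas the real Agent queries its World/State objects:

    import numpy as np

    def improve_policy_sketch(value_function: np.ndarray, gamma: float,
                              transitions) -> np.ndarray:
        """Greedy improvement via one-step lookahead.

        transitions[s][a] is assumed to be a list of
        (probability, next_state, reward) tuples.
        """
        n_states = len(transitions)
        n_actions = len(transitions[0])
        new_policy = np.zeros(n_states, dtype=int)
        for s in range(n_states):
            returns = np.zeros(n_actions)
            for a in range(n_actions):
                for prob, next_s, reward in transitions[s][a]:
                    returns[a] += prob * (reward + gamma * value_function[next_s])
            new_policy[s] = int(np.argmax(returns))
        return new_policy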

plot_q_function()[source]

Plots a 3D density plot representing the distribution of the agent’s Q function values.

run_policy_iteration(max_iterations: int = 200000, gamma: float = 1.0) → None[source]

Estimates the optimal policy function.

Parameters
  • max_iterations (int) – Maximum number of iterations when looking for the policy function.

  • gamma (float) – Discount factor
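
Policy iteration alternates evaluation and greedy improvement until the policy stops changing or max_iterations is exhausted. A self-contained sketch of that outer loop, again over an assumed (probability, next_state, reward) transition structure rather than the Agent’s own attributes:

    import numpy as np

    def policy_iteration_sketch(transitions, n_states: int, n_actions: int,
                                max_iterations: int = 200000,
                                gamma: float = 1.0) -> np.ndarray:
        """Alternate evaluation and improvement until the policy is stable."""
        policy = np.zeros(n_states, dtype=int)
        values = np.zeros(n_states)
        for _ in range(max_iterations):
            # Evaluation: one backup per state under the current policy.
            for s in range(n_states):
                values[s] = sum(prob * (reward + gamma * values[next_s])
                                for prob, next_s, reward in transitions[s][policy[s]])
            # Improvement: act greedily with respect to the updated values.
            new_policy = np.array([
                np.argmax([sum(prob * (reward + gamma * values[next_s])
                               for prob, next_s, reward in transitions[s][a])
                           for a in range(n_actions)])
                for s in range(n_states)
            ])
            if np.array_equal(new_policy, policy):
                break
            policy = new_policy
        return policy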

run_value_iteration(max_iterations: int = 100000, threshold: float = 1e-20, gamma=0.99) → None[source]

Estimates the optimal state-value function.

Parameters
  • max_iterations (int) – Maximum number of iterations when looking for the state-value function.

  • threshold (float) – Minimum change required to continue the value-search iteration.

  • gamma (float) – Discount factor
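
Value iteration instead backs up the maximum action value directly and stops once the largest per-state change drops below threshold. A standalone sketch under the same assumed transition structure as above:

    import numpy as np

    def value_iteration_sketch(transitions, n_states: int, n_actions: int,
                               max_iterations: int = 100000,
                               threshold: float = 1e-20,
                               gamma: float = 0.99) -> np.ndarray:
        """Repeated Bellman optimality backups over all states."""
        values = np.zeros(n_states)
        for _ in range(max_iterations):
            new_values = np.array([
                max(sum(prob * (reward + gamma * values[next_s])
                        for prob, next_s, reward in transitions[s][a])
                    for a in range(n_actions))
                for s in range(n_states)
            ])
            delta = np.max(np.abs(new_values - values))
            values = new_values
            if delta < threshold:
                break
        return values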

solve()[source]

Solves the board using all estimated parameters.

gridworld.models.game module

class gridworld.models.game.Game(world_width: int, world_height: int, start_cell: int, goal_cell: int, obstacles_cells: List[int])[source]

Bases: object

Models a simple game where an agent tries to solve a grid world with the given configuration.

Parameters
  • world (gridworld.models.World) – Board to be solved.

  • agent (gridworld.models.Agent) – Agent that will solve the board.

play(policy_search_iterations: int = 200000, value_search_iterations: int = 100000, threshold: float = 1e-20, gamma: float = 0.8) → Tuple[list, bool][source]

Makes the agent solve the board.

Parameters
  • policy_search_iterations (int) – Maximum number of iterations when looking for optimal policy.

  • value_search_iterations (int) – Maximum number of iterations when looking for optimal state-value function.

  • threshold (float) – Minimum change required to continue the value-search iteration.

  • gamma (float) – Discount factor

Returns

  • player_positions (list) – Cells the agent visited.

  • reached_goal (bool) – True if agent reached goal cell successfully, False otherwise.
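
End-to-end usage following the constructor and play() signatures above. The concrete board layout is only an example; cell IDs are assumed to index the grid the same way as World’s defaults (0 in one corner, world_width * world_height - 1 in the opposite corner):

    from gridworld.models.game import Game

    # Illustrative configuration: 4x4 board, start at cell 0,
    # goal at cell 15, a single obstacle at cell 5.
    game = Game(world_width=4, world_height=4,
                start_cell=0, goal_cell=15, obstacles_cells=[5])

    player_positions, reached_goal = game.play(gamma=0.8)
    print("Visited cells:", player_positions)
    print("Reached goal:", reached_goal)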

gridworld.models.reward module

class gridworld.models.reward.Reward[source]

Bases: enum.Enum

Models all possible rewards.

goal = 1
obstacle = -1
road = 0
start = 0
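
As with Action, members expose their numeric reward via .value; how cell types are mapped to these rewards is handled inside the package, so the loop below only inspects the enum itself:

    from gridworld.models.reward import Reward

    for reward in Reward:
        print(reward.name, reward.value)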

gridworld.models.state module

class gridworld.models.state.State(cell_id: int, possible_moves: dict, cell_type: int)[source]

Bases: object

Models states in a MDP.

cell_id
Type

int

cell_type
Type

int

actions
Type

dict

get_action_results(action: gridworld.models.action.Action) → dict[source]

Retrieves the results of performing the specified action in the current state.

Parameters

action (gridworld.models.Action) – Action to be performed.

Returns

results – Format: {cell_id: int, reward: gridworld.models.Reward, transition_probability: float, is_goal: bool}

Return type

dict
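
The returned dictionary carries everything needed for a one-step Bellman backup. A hedged sketch of how a caller might consume it, assuming the keys are the string names shown in the documented format and that values is an array indexed by cell_id (this mirrors, but is not, the Agent’s code):

    from gridworld.models.action import Action

    def backup_term(state, action: Action, values, gamma: float) -> float:
        """Illustrative contribution of one action outcome to a backup."""
        results = state.get_action_results(action)
        reward = results["reward"].value              # Reward enum -> number
        probability = results["transition_probability"]
        next_cell = results["cell_id"]
        return probability * (reward + gamma * values[next_cell])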

gridworld.models.world module

class gridworld.models.world.World(grid_width: int = 4, grid_height: int = 4, starting_position: int = 0, goal_position: int = 15, obstacle_positions: list = None)[source]

Bases: object

Class modeling the world.

grid

List with world’s elements.

Type

list

grid_width

Grid’s width in cells.

Type

int

grid_height

Grid’s height in cells.

Type

int

starting_position

Cell where the agent will be placed at the start.

Type

int

goal_position

Cell where the agent has to go.

Type

int

obstacle_positions

Cells where obstacles will be placed.

Type

list

get_state(cell_id: int) → gridworld.models.state.State[source]

Retrieves the state of the given cell_id.

Parameters

cell_id (int) – ID of the cell.

Returns

cells_state

Return type

gridworld.models.State
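
A brief usage sketch, assuming cell IDs run from 0 to grid_width * grid_height - 1 as the constructor defaults suggest:

    from gridworld.models.world import World

    world = World()               # default 4x4 grid, start 0, goal 15
    state = world.get_state(0)    # State object for the starting cell
    print(state.cell_id, state.cell_type, list(state.actions))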

print(player_positions: list = None) → list[source]

Prints the world’s grid and the player’s positions.

Parameters

player_positions (list) – Cells the player has been in, excluding the start and goal cells.

Returns

colored_grid – List with the color value of each cell.

Return type

list
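
Combined with Game.play, this method can render the path the agent took. The path below is illustrative, not output from an actual run:

    from gridworld.models.world import World

    world = World(grid_width=4, grid_height=4,
                  starting_position=0, goal_position=15,
                  obstacle_positions=[5])

    # Render an assumed path through the grid; returns per-cell color values.
    colored_grid = world.print(player_positions=[1, 2, 3, 7, 11])
    print(colored_grid)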