gridworld.models package¶
Submodules¶
gridworld.models.action module¶
gridworld.models.agent module¶
-
class
gridworld.models.agent.
Agent
[source]¶ Bases:
object
Models a Markov Decision Process-based agent.
- Parameters
policy (numpy.ndarray) – Agent's current policy.
state_value_function (numpy.ndarray) – Estimated value of each state.
q_function (dict) – Estimated Q values for each state-action pair.
environment (gridworld.models.World) – World the agent interacts with.
-
evaluate_policy
(policy: numpy.ndarray) → numpy.ndarray[source]¶ Evaluates a policy (Q value for each cell).
- Parameters
policy (numpy.ndarray) – Policy to be evaluated using the agent's estimated Q function.
- Returns
state_value_function
- Return type
numpy.ndarray
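As background, policy evaluation repeatedly applies the Bellman expectation backup until the values settle. A minimal sketch on a hypothetical two-state deterministic MDP (the toy MDP and all names below are illustrative, not part of this package):

```python
import numpy as np

# Hypothetical toy MDP: two states; under the fixed policy, state 0 moves to
# state 1 (reward 1) and state 1 loops on itself (reward 0).
transitions = np.array([1, 1])   # next state under the policy
rewards = np.array([1.0, 0.0])   # reward for following the policy in each state
gamma = 0.9

# Bellman expectation backup: V(s) = r(s) + gamma * V(next(s))
v = np.zeros(2)
for _ in range(100):
    v = rewards + gamma * v[transitions]
```

Here V converges to [1, 0]: state 1 is absorbing with no reward, so only the transition into it pays off.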
-
improve_policy
(value_function, gamma) → numpy.ndarray[source]¶ Computes a new policy for the given values.
- Parameters
value_function (numpy.ndarray) – Maximum Q values obtained with the current policy
gamma (float) – Discount factor
- Returns
new_policy
- Return type
numpy.ndarray
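Policy improvement itself is a per-state argmax over Q values; a sketch with a hypothetical Q table (an array here, whereas the package stores its Q function in a dict):

```python
import numpy as np

# Hypothetical Q table: one row per state, one column per action.
q = np.array([[0.2, 0.8],
              [0.5, 0.1]])

# Greedy improvement: in each state, choose the action with the highest Q value.
new_policy = np.argmax(q, axis=1)  # action 1 in state 0, action 0 in state 1
```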
-
plot_q_function
()[source]¶ Plots a 3D density plot representing the distribution of the agent's Q function values.
-
run_policy_iteration
(max_iterations: int = 200000, gamma: float = 1.0) → None[source]¶ Estimates the optimal policy.
- Parameters
max_iterations (int) – Maximum number of iterations when looking for the policy function.
gamma (float) – Discount factor
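Policy iteration alternates evaluation and greedy improvement until the policy stops changing. A compact sketch on a hypothetical two-state, two-action deterministic MDP (names and dynamics are illustrative, not the package's):

```python
import numpy as np

# Hypothetical deterministic MDP: next_state[s, a] and reward[s, a].
next_state = np.array([[0, 1], [1, 1]])
reward = np.array([[0.0, 1.0], [0.0, 0.0]])
gamma = 0.9
policy = np.zeros(2, dtype=int)
states = np.arange(2)

for _ in range(50):
    # Evaluate the current policy.
    v = np.zeros(2)
    for _ in range(100):
        v = reward[states, policy] + gamma * v[next_state[states, policy]]
    # Improve it greedily via a one-step lookahead.
    new_policy = np.argmax(reward + gamma * v[next_state], axis=1)
    if np.array_equal(new_policy, policy):  # stable policy: done
        break
    policy = new_policy
```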
-
run_value_iteration
(max_iterations: int = 100000, threshold: float = 1e-20, gamma=0.99) → None[source]¶ Estimates the optimal state-value function.
- Parameters
max_iterations (int) – Maximum number of iterations when looking for the state-value function.
threshold (float) – Minimum change per iteration required to continue the value search.
gamma (float) – Discount factor
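Value iteration replaces full policy evaluation with a single Bellman optimality backup per sweep, stopping once the largest value change falls below threshold. A sketch on a hypothetical toy MDP (not the package's own code):

```python
import numpy as np

# Hypothetical deterministic MDP: next_state[s, a] and reward[s, a].
next_state = np.array([[0, 1], [1, 1]])
reward = np.array([[0.0, 1.0], [0.0, 0.0]])
gamma, threshold = 0.99, 1e-20

v = np.zeros(2)
for _ in range(100000):
    # Bellman optimality backup: V(s) = max_a [ r(s, a) + gamma * V(s') ]
    new_v = np.max(reward + gamma * v[next_state], axis=1)
    if np.max(np.abs(new_v - v)) < threshold:  # values stopped changing
        v = new_v
        break
    v = new_v
```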
gridworld.models.game module¶
-
class
gridworld.models.game.
Game
(world_width: int, world_height: int, start_cell: int, goal_cell: int, obstacles_cells: List[int])[source]¶ Bases:
object
Models a simple game where an agent tries to solve a grid world with the given configuration.
- Parameters
world (gridworld.models.World) – Board to be solved.
agent (gridworld.models.Agent) – Agent that will solve the board.
-
play
(policy_search_iterations: int = 200000, value_search_iterations: int = 100000, threshold: float = 1e-20, gamma: float = 0.8) → Tuple[list, bool][source]¶ Makes the agent solve the board.
- Parameters
policy_search_iterations (int) – Maximum number of iterations when looking for optimal policy.
value_search_iterations (int) – Maximum number of iterations when looking for optimal state-value function.
threshold (float) – Minimum change per iteration required to continue the value search.
gamma (float) – Discount factor
- Returns
player_positions (list) – Cells the agent visited, in order.
reached_goal (bool) – True if the agent reached the goal cell, False otherwise.
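As a self-contained illustration of what play does conceptually: solve a small board with value iteration, then follow the greedy policy from start to goal, recording visited cells. The 2x2 board, its move rules, and all names below are hypothetical, not the package's own:

```python
import numpy as np

# Hypothetical 2x2 grid, cells numbered row-major; start at 0, goal at 3.
width, n, goal = 2, 4, 3

def step(cell, action):
    """Deterministic move: action 0 = right, 1 = down; invalid moves stay put."""
    if action == 0 and cell % width < width - 1:
        return cell + 1
    if action == 1 and cell < n - width:
        return cell + width
    return cell

next_state = np.array([[step(s, a) for a in range(2)] for s in range(n)])
# Reward 1 only for moves that enter the goal cell from elsewhere.
reward = np.where((next_state == goal) & (np.arange(n)[:, None] != goal), 1.0, 0.0)

gamma, v = 0.8, np.zeros(n)
for _ in range(100):  # value iteration
    v = np.max(reward + gamma * v[next_state], axis=1)

# Roll out the greedy policy, collecting the visited cells.
path, cell = [0], 0
while cell != goal and len(path) <= n:
    cell = int(next_state[cell][np.argmax(reward[cell] + gamma * v[next_state[cell]])])
    path.append(cell)
reached_goal = cell == goal
```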
gridworld.models.reward module¶
gridworld.models.state module¶
-
class
gridworld.models.state.
State
(cell_id: int, possible_moves: dict, cell_type: int)[source]¶ Bases:
object
Models a state in an MDP.
-
cell_id
¶ - Type
int
-
cell_type
¶ - Type
int
-
actions
¶ - Type
dict
-
get_action_results
(action: gridworld.models.action.Action) → dict[source]¶ Retrieves the results of performing the specified action in the current state.
- Parameters
action (gridworld.models.Action) – Action to be performed.
- Returns
results – Format: {cell_id: int, reward: gridworld.models.Reward, transition_probability: float, is_goal: bool}
- Return type
dict
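The documented return shape can be illustrated with a literal dict; every value below is hypothetical, and reward here is a plain number standing in for a gridworld.models.Reward:

```python
# Hypothetical result of moving right from some cell in a deterministic world.
results = {
    "cell_id": 5,                   # cell the action leads to
    "reward": -1,                   # stand-in for a gridworld.models.Reward
    "transition_probability": 1.0,  # deterministic transition
    "is_goal": False,               # the destination is not the goal cell
}
```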
-
gridworld.models.world module¶
-
class
gridworld.models.world.
World
(grid_width: int = 4, grid_height: int = 4, starting_position: int = 0, goal_position: int = 15, obstacle_positions: list = None)[source]¶ Bases:
object
Class modeling the world.
-
grid
¶ List with world’s elements.
- Type
list
-
grid_width
¶ Grid’s width in cells.
- Type
int
-
grid_height
¶ Grid’s height in cells.
- Type
int
-
starting_position
¶ Cell where the agent will be placed at the start.
- Type
int
-
goal_position
¶ Cell where the agent has to go.
- Type
int
-
obstacle_positions
¶ Cells where obstacles will be placed.
- Type
list
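To make these attributes concrete, a sketch that renders a default-sized board from such fields; the obstacle cells and helper names are hypothetical:

```python
# Hypothetical default 4x4 world: start at cell 0, goal at cell 15.
grid_width, grid_height = 4, 4
starting_position, goal_position = 0, 15
obstacle_positions = [5, 7]  # illustrative obstacle cells

def cell_symbol(cell):
    """S = start, G = goal, X = obstacle, . = free cell."""
    if cell == starting_position:
        return "S"
    if cell == goal_position:
        return "G"
    return "X" if cell in obstacle_positions else "."

grid = [cell_symbol(c) for c in range(grid_width * grid_height)]
rows = ["".join(grid[r * grid_width:(r + 1) * grid_width])
        for r in range(grid_height)]
```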