Added type alias for node position.

This commit is contained in:
Filipe Rodrigues 2023-10-02 05:10:25 +01:00
parent 306c26c8e1
commit 8ba26b5a99

View File

@ -14,6 +14,8 @@ import numpy
import arc_proj.util as util
from arc_proj.agent import Agent
# Node position type
NodePos = Tuple[int, int]
@dataclass
class GraphCache:
@ -22,10 +24,10 @@ class GraphCache:
"""
# Unsatisfied nodes
unsatisfied_nodes: set[Tuple[int, int]] = dataclasses.field(default_factory=set)
unsatisfied_nodes: set[NodePos] = dataclasses.field(default_factory=set)
# Empty nodes
empty_nodes: set[Tuple[int, int]] = dataclasses.field(default_factory=set)
empty_nodes: set[NodePos] = dataclasses.field(default_factory=set)
class Graph:
@ -69,7 +71,7 @@ class Graph:
# Initialize caches
self.cache.empty_nodes = set(node_pos for node_pos in self.graph.nodes)
def add_agent(self, node_pos: Tuple[int, int], agent: Agent):
def add_agent(self, node_pos: NodePos, agent: Agent):
"""
Adds an agent `agent` at node `node_pos`.
@ -84,7 +86,7 @@ class Graph:
self.cache.empty_nodes.remove(node_pos)
self.update_unsatisfied_nodes_cache(node_pos, skip_node=False)
def remove_agent(self, node_pos: Tuple[int, int]) -> Agent:
def remove_agent(self, node_pos: NodePos) -> Agent:
"""
Removes an agent at node `node_pos`.
@ -131,7 +133,7 @@ class Graph:
return agents
def update_unsatisfied_nodes_cache(self, node_pos: Tuple[int, int], skip_node: bool):
def update_unsatisfied_nodes_cache(self, node_pos: NodePos, skip_node: bool):
"""
Updates the unsatisfied nodes cache for the node `node_pos` and neighbors.
@ -166,7 +168,7 @@ class Graph:
agent = numpy.random.choice(list(agent_weights.keys()), p=list(agent_weights.values()))
self.add_agent(node_pos, agent)
def agent_satisfaction(self, node_pos: Tuple[int, int]) -> float | None:
def agent_satisfaction(self, node_pos: NodePos) -> float | None:
"""
Returns the satisfaction of an agent, from 0.0 to 1.0.