diff --git a/src/arc_proj/graph.py b/src/arc_proj/graph.py index bf3227f..8658991 100644 --- a/src/arc_proj/graph.py +++ b/src/arc_proj/graph.py @@ -46,11 +46,16 @@ class Graph: """ # Inner graph + # Note: Only used for edge-lookups, no values are stored + # within it. graph: nx.Graph # Graph size size: Tuple[int, int] + # Agents + agents: dict[NodePos, Agent] + # Random state random_state: numpy.random.RandomState @@ -68,6 +73,7 @@ class Graph: # Create the graph and initialize it to empty self.size = graph_size self.graph: nx.Graph = nx.grid_2d_graph(graph_size[0], graph_size[1]) + self.agents = dict() self.random_state = numpy.random.RandomState(seed) self.debug = DebugOptions( sanity_check_caches=False, @@ -100,19 +106,17 @@ class Graph: continue # Get the agent in the node (but try our agent cache first, because we've might overridden it already) - agent = agents_cache.pop(src_pos, None) - if agent is None: - src_node = self.graph.nodes[src_pos] - assert 'agent' in src_node, f"Node position {src_node} did not have an agent" - agent = src_node['agent'] - del src_node['agent'] + src_agent = agents_cache.pop(src_pos, None) + if src_agent is None: + src_agent = self.agents.pop(src_pos, None) + assert src_agent is not None, f"Node position {src_pos} did not have an agent" self.cache.empty_nodes.add(src_pos) # Then save the agent in the to slot (in case we need it later), then write our agent into it - dst_node = self.graph.nodes[dst_pos] - if 'agent' in dst_node: - agents_cache[dst_pos] = dst_node['agent'] - dst_node['agent'] = agent + dst_agent = self.agents.get(dst_pos, None) + if dst_agent is not None: + agents_cache[dst_pos] = dst_agent + self.agents[dst_pos] = src_agent self.cache.empty_nodes.discard(dst_pos) assert len(agents_cache) == 0, f"Destination nodes overlapped or destination node already had an agent: {agents_cache}" @@ -128,9 +132,8 @@ class Graph: """ # Set it on the graph - node = self.graph.nodes[node_pos] - assert 'agent' not in node, f"Node position {node_pos} already had an agent: {node['agent']}" - node['agent'] = agent + assert node_pos not in self.agents, f"Node position {node_pos} already had an agent: {self.agents[node_pos]}" + self.agents[node_pos] = agent # Then update the caches self.cache.empty_nodes.remove(node_pos) @@ -144,10 +147,8 @@ class Graph: """ # Remove it from the graph - node = self.graph.nodes[node_pos] - assert 'agent' in node, f"Node position {node_pos} did not have agent" - agent = node['agent'] - del node['agent'] + agent = self.agents.pop(node_pos, None) + assert agent is not None, f"Node position {node_pos} did not have agent" # Then update the caches self.cache.empty_nodes.add(node_pos) @@ -204,7 +205,7 @@ class Graph: # Select a random agent agent = self.random_state.choice(list(agent_weights.keys()), p=list(agent_weights.values())) - self.graph.nodes[node_pos]['agent'] = agent + self.agents[node_pos] = agent self.cache.empty_nodes.remove(node_pos) # At the end, update the remaining caches @@ -218,19 +219,18 @@ class Graph: """ # If the node is empty, it doesn't have a satisfaction - node = self.graph.nodes[node_pos] - if 'agent' not in node: + agent = self.agents.get(node_pos, None) + if agent is None: return None - agent: Agent = node['agent'] # Else count all neighbors that aren't empty neighbors: list[Agent] = [] for neighbor_pos in self.graph.adj[node_pos]: - neighbor_node = self.graph.nodes[neighbor_pos] - if 'agent' not in neighbor_node: + neighbor_agent = self.agents.get(neighbor_pos, None) + if neighbor_agent is None: continue - neighbors.append(neighbor_node['agent']) + neighbors.append(neighbor_agent) return agent.satisfaction(neighbors) @@ -246,7 +246,7 @@ class Graph: # Then check with the agent # Note: Since `satisfaction` isn't `None`, we know it must exist - agent: Agent = self.graph.nodes[node_pos]['agent'] + agent = self.agents[node_pos] return satisfaction >= agent.threshold() @@ -255,8 +255,8 @@ class Graph: Returns an image of all agents in the graph """ img = [[(0, 0, 0) for _ in range(self.size[0])] for _ in range(self.size[1])] - for node_pos, node in self.graph.nodes(data = True): - agent: Agent | None = util.try_index_dict(node, 'agent') + for node_pos in self.graph.nodes: + agent = self.agents.get(node_pos, None) img[node_pos[1]][node_pos[0]] = agent.color() if agent is not None else [0.5, 0.5, 0.5] return img @@ -266,7 +266,7 @@ class Graph: Returns an image of the satisfaction of all agents in the graph """ img = [[(0, 0, 0) for _ in range(self.size[0])] for _ in range(self.size[1])] - for node_pos, node in self.graph.nodes(data=True): + for node_pos in self.graph.nodes: satisfaction = self.agent_satisfaction(node_pos) satisfied = self.agent_satisfied(node_pos) img[node_pos[1]][node_pos[0]] = [satisfaction, 0.0, satisfied] if satisfaction is not None else [1.0, 0.0, 1.0] @@ -279,7 +279,7 @@ class Graph: """ node_pos = { node_pos: node_pos for node_pos in self.graph.nodes()} - agents = (util.try_index_dict(node, 'agent') for _, node in self.graph.nodes(data=True)) + agents = (self.agents.get(node_pos, None) for node_pos in self.graph.nodes) node_colors = [agent.color() if agent is not None else [0.5, 0.5, 0.5] for agent in agents] # And finally draw the nodes and edges @@ -304,7 +304,7 @@ class Graph: case True | None: assert node_pos not in self.cache.unsatisfied_nodes, f"Node {node_pos} was satisfied, but present in unsatisfied cache" case False : assert node_pos in self.cache.unsatisfied_nodes, f"Node {node_pos} wasn't satisfied, but not present in unsatisfied cache" - if 'agent' not in self.graph.nodes[node_pos]: + if node_pos not in self.agents: assert node_pos in self.cache.empty_nodes, f"Node {node_pos} was empty, but not present in empty cache" # Move all the current unsatisfied to another place