import functools
import math
import os.path as osp
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union
import matplotlib.animation as ani
import matplotlib.collections as collections
import matplotlib.figure as figure
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.axes import Axes
from .arg_parser import parse_args
from .time import Time
if TYPE_CHECKING:
from episimmer.model import BaseModel
from episimmer.simulate import Simulate
def plot_results(example_path: str, model: 'BaseModel',
                 avg_dict: Dict[str, List[float]],
                 stddev_dict: Dict[str, List[float]],
                 max_dict: Dict[str, List[int]],
                 min_dict: Dict[str, List[int]], plot: bool) -> None:
    """
    Plots the epidemic trajectory and saves it as ``results/results.jpg``
    inside ``example_path``.

    For every state the average trajectory is drawn as a line, and the band
    between the per-world minimum and maximum is shaded in the same color.

    Args:
        example_path: Path to directory containing simulation files
        model: Disease model used (supplies ``name`` and per-state ``colors``)
        avg_dict: Average of epidemic trajectory
        stddev_dict: Standard deviation of epidemic trajectory (accepted for
            interface compatibility; it is not drawn)
        max_dict: Maximum values of epidemic trajectory across worlds
        min_dict: Minimum values of epidemic trajectory across worlds
        plot: Boolean used to plot the epidemic trajectory
    """
    for state, averages in avg_dict.items():
        state_color = model.colors[state]
        x = np.arange(0, len(averages))
        # Label the line itself: a bare list of labels passed to legend() is
        # matched to artists by position only and can land on the shaded band.
        plt.plot(averages, color=state_color, label=state)
        plt.fill_between(x,
                         min_dict[state],
                         max_dict[state],
                         alpha=0.2,
                         facecolor=state_color,
                         linewidth=0)

    plt.title(model.name + ' Plot')
    plt.legend(loc='upper right', shadow=True)
    plt.ylabel('Population')
    plt.xlabel('Time Steps (in unit steps)')
    # The on/off flag is passed positionally: its keyword was renamed from
    # ``b`` to ``visible`` in newer matplotlib, the positional form works on both.
    plt.grid(True, which='major', color='#666666', linestyle='-')
    plt.minorticks_on()
    plt.grid(True, which='minor', color='#999999', linestyle='-', alpha=0.2)

    fig = plt.gcf()
    fig.set_size_inches(8, 5)
    if plot:
        plt.show()
    # Saved through the figure handle obtained before show().
    fig.savefig(osp.join(example_path, 'results', 'results.jpg'))
def buildgraph(i: int, model: 'BaseModel',
               avg_dict: Dict[str, List[float]]) -> None:
    """
    Builds the epidemic trajectory graph for current frame i.

    The current pyplot figure is cleared and redrawn from scratch, showing
    only the first ``i`` points of every state's average trajectory.

    Args:
        i: Current frame
        model: Disease model used (supplies ``name`` and per-state ``colors``)
        avg_dict: Average of epidemic trajectory
    """
    plt.clf()
    plt.title(model.name + ' Plot')
    plt.ylabel('Population')
    plt.xlabel('Time Steps (in unit steps)')
    # The on/off flag is passed positionally: its keyword was renamed from
    # ``b`` to ``visible`` in newer matplotlib, the positional form works on both.
    plt.grid(True, which='major', color='#666666', linestyle='-')
    plt.minorticks_on()
    plt.grid(True, which='minor', color='#999999', linestyle='-', alpha=0.2)

    for state, averages in avg_dict.items():
        plt.plot(averages[:i], label=state, color=model.colors[state])
    plt.legend(loc='upper left', shadow=True)
def store_animated_time_plot(example_path: str, model: 'BaseModel',
                             avg_dict: Dict[str, List[float]]) -> None:
    """
    Saves the animation of epidemic trajectory to a gif file
    (``results/time_plot.gif`` inside ``example_path``, written at 10 fps).

    Args:
        example_path: Path to directory containing simulation files
        model: Disease model used
        avg_dict: Average of epidemic trajectory
    """
    # Frame i draws the first i points of each trajectory, so a trajectory of
    # n points needs frames 0..n. Without an explicit frame count the saved
    # length would fall back to matplotlib's save_count instead of the data.
    longest = max((len(series) for series in avg_dict.values()), default=0)
    n_frames = longest + 1

    fig = plt.figure()
    try:
        fig.set_size_inches(8, 5)
        anim = ani.FuncAnimation(fig,
                                 buildgraph,
                                 frames=n_frames,
                                 interval=100,
                                 fargs=(model, avg_dict))
        anim.save(osp.join(example_path, 'results', 'time_plot.gif'),
                  writer=ani.PillowWriter(fps=10))
    finally:
        # Release the figure so it does not linger as pyplot's current figure.
        plt.close(fig)
def get_interaction_graph_from_object(obj: 'Simulate') -> nx.Graph:
    """
    Generates the interaction graph from the simulation object.

    The graph contains:

    * one node per agent, colored by the agent's current state and carrying
      a grid coordinate in its ``pos`` attribute;
    * an edge for every contact between an agent that can contribute
      infection and a contacted agent that can receive infection;
    * one turquoise node per event at every location that is not locked
      down, connected to each participating agent that can receive or
      contribute infection.

    Args:
        obj: Simulation object

    Returns:
        Interaction graph
    """
    agents_obj = obj.agents_obj
    model = obj.model
    locations_obj = obj.locations_obj
    agents_dict = agents_obj.agents
    # Width of the (roughly square) agent grid; never 0 so the modulo and
    # floor division below are always defined.
    root_num = max(1, int(math.sqrt(agents_obj.n)))
    g = nx.Graph()

    # Agent Nodes
    for i, agent in enumerate(agents_dict.values()):
        # Column = i % width, row = i // width (floor division keeps rows on
        # whole grid lines instead of producing a slanted layout).
        g.add_node(agent.index,
                   color=model.colors[agent.state],
                   pos=(500 * (i % root_num), 500 * (i // root_num)))

    # Interactions
    for agent in agents_dict.values():
        if agent.can_contribute_infection > 0:
            for int_agent in agent.contact_list:
                int_agent_indx = int_agent['Interacting Agent Index']
                if agents_dict[int_agent_indx].can_receive_infection > 0:
                    g.add_edge(agent.index, int_agent_indx, color='black')

    # Events
    for j, location in enumerate(locations_obj.locations.values()):
        if location.lock_down_state:
            continue
        for i, event_info in enumerate(location.events):
            event_node = event_info['Location Index'] + '_event' + str(i)
            # Events are placed to the left of the agent grid, one column per
            # location and one row per event.
            g.add_node(event_node,
                       color='#40E0D0',
                       pos=(-1500 - 500 * j, 500 * i))
            for agent in event_info['Agents']:
                participant = agents_dict[agent]
                if (participant.can_receive_infection > 0
                        or participant.can_contribute_infection > 0):
                    g.add_edge(event_node, agent, color='black')
    return g
def save_env_graph() -> Callable:
    """
    Decorator factory that records the interaction graph on the
    :class:`~episimmer.simulate.Simulate` object.

    The wrapped function runs first. Afterwards, if the current world is the
    last configured one and the ``viz_dyn`` command line flag is set, the
    interaction graph is built from the simulation object and appended to
    its ``g_list``. The wrapper itself returns ``None``.

    Returns:
        Callable function
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(ref: 'Simulate', *args, **kwargs) -> None:
            func(ref, *args, **kwargs)

            last_world = ref.config_obj.worlds - 1
            if last_world != Time.get_current_world():
                return

            cmd_args = parse_args()
            if not cmd_args.viz_dyn:
                return

            graph = get_interaction_graph_from_object(ref)
            ref.g_list.append(graph)

        return wrapper

    return decorator
def set_ax_params(ax: Axes, model: 'BaseModel', time_step: int) -> Axes:
    """
    Sets the title and legend of the Axes for the interaction graph.

    Legend entries are produced by scattering one labelled marker per
    disease state (plus one for event locations) at the origin.

    Args:
        ax: Axes
        model: Disease model used
        time_step: Current time step

    Returns:
        Axes
    """
    title_font = {'fontsize': 18}
    ax.set_title(f'Time step {time_step}', title_font)

    # Lazily pair every state with its color so lookups stay interleaved
    # with the drawing calls.
    state_markers = ((model.colors[name], name)
                     for name in model.individual_state_types)
    for marker_color, marker_label in state_markers:
        ax.scatter([0], [0], color=marker_color, label=marker_label)

    # Events - Turquoise
    ax.scatter([0], [0], color='#40E0D0', label='Event location')
    # Unlabelled white marker drawn last at the same spot; presumably it
    # paints over the legend-only markers stacked at the origin.
    ax.scatter([0], [0], color='#FFFFFF')

    legend_font = {'size': 12}
    ax.legend(bbox_to_anchor=(1, 1), prop=legend_font)
    return ax
def draw_graph(
    g: nx.Graph, ax: Axes, seed: Union[str, int]
) -> Tuple[collections.PatchCollection, collections.LineCollection]:
    """
    Sets the node positions and edges according to the spring layout and
    returns them.

    Node colors are read from each node's ``color`` attribute. Any ``pos``
    attribute stored on the nodes is ignored; positions always come from
    the seeded spring layout.

    Args:
        g: Current interaction graph
        ax: Axes
        seed: Seed for consistent graph (converted with ``int``)

    Returns:
        Nodes and Edges
    """
    # NOTE(review): networkx documents the node artist as a PathCollection,
    # while the annotation above says PatchCollection - confirm and align.
    # Materialize the colors as a list rather than handing a dict view to
    # networkx/matplotlib.
    node_colors = list(nx.get_node_attributes(g, 'color').values())

    # Layout positions
    pos = nx.spring_layout(g, seed=int(seed))
    nodes = nx.draw_networkx_nodes(g, pos, node_color=node_colors, ax=ax)
    edges = nx.draw_networkx_edges(g,
                                   pos,
                                   ax=ax,
                                   connectionstyle='arc3, rad = 0.1')
    return nodes, edges
def animate_graph(
    time_step: int, fig: figure.Figure, model: 'BaseModel',
    g_list: List[nx.Graph], seed: Union[str, int]
) -> Tuple[collections.PatchCollection, collections.LineCollection]:
    """
    Sets up the figure to plot the current time step interaction network.

    The figure is cleared, its axes receive the title and legend for this
    time step, and the matching graph from ``g_list`` is drawn. The index
    wraps around when ``time_step`` exceeds the number of stored graphs.

    Args:
        time_step: Current time step
        fig: Figure used to plot
        model: Disease model used
        g_list: List of interaction graphs for every time step
        seed: Seed for consistent graph; any falsy value falls back to 42

    Returns:
        Nodes and edges for current time step
    """
    layout_seed = seed if seed else 42

    fig.clf()
    axes = set_ax_params(fig.gca(), model, time_step)

    graph_index = time_step % len(g_list)
    return draw_graph(g_list[graph_index], axes, layout_seed)
def store_animated_dynamic_graph() -> Callable:
    """
    Decorator to store the evolving interactions graph as a gif.

    Before the wrapped function runs, if the current world is the last
    configured one and the ``viz_dyn`` command line flag is set, the graphs
    in ``ref.g_list`` are animated (one frame per configured time step, 5
    fps) and written to ``results/dyn_graph.gif`` inside the example path.
    The wrapped function is then called and its result returned.

    Returns:
        Callable function
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(ref: 'Simulate', *args, **kwargs) -> None:
            if ref.config_obj.worlds - 1 == Time.get_current_world():
                cmd_args = parse_args()
                if cmd_args.viz_dyn:
                    fig = plt.figure()
                    try:
                        fig.set_size_inches(20, 14)
                        anim = ani.FuncAnimation(
                            fig,
                            animate_graph,
                            frames=ref.config_obj.time_steps,
                            fargs=(fig, ref.model, ref.g_list,
                                   ref.config_obj.random_seed))
                        anim.save(osp.join(ref.config_obj.example_path,
                                           'results', 'dyn_graph.gif'),
                                  writer=ani.PillowWriter(fps=5))
                    finally:
                        # Close (not just clear) the large figure so it is
                        # released even if saving fails.
                        plt.close(fig)
            return func(ref, *args, **kwargs)

        return wrapper

    return decorator