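"""
Visualise a boosted-tree classifier of the flu model results as a grid of
factor maps.

An AdaBoost ensemble of depth-3 decision trees is fitted to predict whether
the last value of ``deceased_population_region 1`` exceeds 1,000,000. The
(up to) five most important input dimensions are then shown pairwise:
factor maps in the lower triangle and per-class histograms on the diagonal.
"""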
4
5import matplotlib.pyplot as plt
6import numpy as np
7import seaborn as sns
8from matplotlib.collections import CircleCollection
9from sklearn.ensemble import AdaBoostClassifier
10from sklearn.tree import DecisionTreeClassifier
11
12from ema_workbench import load_results, ema_logging
13from ema_workbench.analysis import feature_scoring
14
# Send ema_workbench log output to stderr; INFO is presumably the minimum
# level reported (confirm against ema_logging.log_to_stderr).
ema_logging.log_to_stderr(ema_logging.INFO)
16
17
def plot_factormap(x1, x2, ax, bdt, nominal, resolution=500):
    """Plot a 2d factor map of a fitted classifier over two input dimensions.

    The classifier is evaluated on a regular ``resolution`` x ``resolution``
    grid spanning the observed range of dimensions ``x1`` and ``x2``; every
    other dimension is held at its value in ``nominal``. Predicted classes are
    drawn as filled contours and the observed experiments are overlaid as a
    scatter, one series per outcome class.

    Reads the module-level ``x`` (experiments array), ``y`` (boolean outcome
    per experiment) and ``columns`` (dimension names).

    Parameters
    ----------
    x1 : int
        Column index of ``x`` shown on the horizontal axis.
    x2 : int
        Column index of ``x`` shown on the vertical axis.
    ax : Axes
        Axes to draw on.
    bdt : classifier
        Fitted object with a ``predict`` method that accepts a 2d array with
        the same columns as ``x``.
    nominal : array-like
        One value per column of ``x``; used for the dimensions not plotted.
    resolution : int, optional
        Number of grid points along each axis (default 500).
    """
    x_min, x_max = x[:, x1].min(), x[:, x1].max()
    y_min, y_max = x[:, x2].min(), x[:, x2].max()
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, resolution),
        np.linspace(y_min, y_max, resolution),
    )

    # every grid row starts at the nominal point; only the two plotted
    # dimensions are then varied
    grid = np.ones((xx.ravel().shape[0], x.shape[1])) * nominal
    grid[:, x1] = xx.ravel()
    grid[:, x2] = yy.ravel()

    Z = bdt.predict(grid).reshape(xx.shape)
    ax.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.5)  # @UndefinedVariable

    # one scatter call per outcome class so the classes form separate series
    for i in (0, 1):
        idx = y == i
        ax.scatter(x[idx, x1], x[idx, x2], s=5)
    ax.set_xlabel(columns[x1])
    ax.set_ylabel(columns[x2])
38
39
def plot_diag(x1, ax):
    """Plot overlaid histograms of one input dimension, split by outcome class.

    Both class histograms use the same value range (the observed min/max of
    the dimension) and the same default bin count, so their bin edges line
    up. The x axis is labelled with the dimension name, matching
    ``plot_factormap``, so a diagonal panel in the bottom row of the grid is
    labelled like its neighbours.

    Reads the module-level ``x`` (experiments array), ``y`` (boolean outcome
    per experiment) and ``columns`` (dimension names).

    Parameters
    ----------
    x1 : int
        Column index of ``x`` to histogram.
    ax : Axes
        Axes to draw on.
    """
    x_min, x_max = x[:, x1].min(), x[:, x1].max()
    for i in (0, 1):
        idx = y == i
        ax.hist(x[idx, x1], range=(x_min, x_max), alpha=0.5)
    ax.set_xlabel(columns[x1])
45
46
# load data
experiments, outcomes = load_results("./data/1000 flu cases with policies.tar.gz")

# transform to numpy array with proper recoding of categorical variables
x, columns = feature_scoring._prepare_experiments(experiments)
# binary outcome: True when the last entry along axis 1 (presumably the final
# time step) of "deceased_population_region 1" exceeds one million
y = outcomes["deceased_population_region 1"][:, -1] > 1000000

# establish the nominal (mid-range) case for factor maps; dimensions not shown
# in a panel are held at this value. This is questionable in particular for
# categorical dimensions, where the midpoint need not be a valid category.
minima = x.min(axis=0)
maxima = x.max(axis=0)
nominal = minima + (maxima - minima) / 2

# fit the boosted tree
bdt = AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), algorithm="SAMME", n_estimators=200)
bdt.fit(x, y)

# determine which dimensions are most important (descending importance)
sorted_indices = np.argsort(bdt.feature_importances_)[::-1]

# show at most the top five dimensions, but never more than the data has,
# otherwise sorted_indices[j] below would run out of range
n_dims = min(5, len(sorted_indices))

# do the actual plotting
# this is a quick hack, tying it to seaborn Pairgrid is probably
# the more elegant solution, but is tricky with what arguments
# can be passed to the plotting function
# squeeze=False keeps `axes` two-dimensional even for a 1x1 grid
fig, axes = plt.subplots(
    ncols=n_dims, nrows=n_dims, figsize=(3 * n_dims, 3 * n_dims), squeeze=False
)

for i, row in enumerate(axes):
    for j, ax in enumerate(row):
        if i > j:
            # lower triangle: column dimension on x, row dimension on y
            plot_factormap(sorted_indices[j], sorted_indices[i], ax, bdt, nominal)
        elif i == j:
            plot_diag(sorted_indices[j], ax)
        else:
            # upper triangle is unused
            ax.axis("off")

        # only the outer edge of the grid keeps tick labels and axis labels
        if j > 0:
            ax.set_yticklabels([])
            ax.set_ylabel("")
        if i < len(axes) - 1:
            ax.set_xticklabels([])
            ax.set_xlabel("")

# add a figure-level legend for the two outcome classes, placed in the upper
# right corner over the empty upper triangle. The handle colours are the first
# two palette entries, assumed to match the colours the scatter/hist calls
# pick up from the default colour cycle.
palette = sns.color_palette()
handles = [
    CircleCollection([10], color=palette[0]),
    CircleCollection([10], color=palette[1]),
]

legend = fig.legend(handles, ["False", "True"], scatterpoints=1, loc="upper right")

plt.tight_layout()
plt.show()