Benchmark Visualization
Setup
[1]:
import ast
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
[2]:
# Global seaborn theme applied to every figure in this notebook: "bright"
# palette with color_codes enabled, "white" style, and the "talk" plotting
# context with all font sizes scaled by 1.5.
_theme_kwargs = {
    "color_codes": True,
    "palette": "bright",
    "style": "white",
    "context": "talk",
    "font_scale": 1.5,
}
sns.set(**_theme_kwargs)
[3]:
def load_result(filename):
    """
    Load benchmark results from a text file.

    Every non-blank line is parsed with ``ast.literal_eval`` (so each line
    must be a Python literal, e.g. a list of numbers) and the parsed values
    are returned in file order.

    Parameters
    ----------
    filename : str or path-like
        Path of the results file.

    Returns
    -------
    list
        One parsed literal per non-blank line.

    Raises
    ------
    OSError
        If the file cannot be opened.
    ValueError, SyntaxError
        If a non-blank line is not a valid Python literal.
    """
    # `with` guarantees the handle is closed (the previous version leaked it).
    with open(filename, "r", encoding="utf-8") as inputs:
        # Skip blank lines (e.g. a trailing empty line), which literal_eval
        # would otherwise reject with a SyntaxError.
        return [ast.literal_eval(line) for line in inputs if line.strip()]
def plot_acc(col, ls, pos, n_train):
    """
    Draw every curve stored in ``ls[pos]`` on the axes ``col``.

    Each curve is plotted against ``n_train`` with line width 5. Only when
    ``pos == 0`` are the curves labelled, using the global ``legends`` list in
    order, so that a figure-level legend gets a single entry per estimator.
    Despite its name, this helper is also used for the wall-time panels.
    """
    label_curves = pos == 0
    for idx, curve in enumerate(ls[pos]):
        style = {"lw": 5}
        if label_curves:
            style["label"] = legends[idx]
        col.plot(n_train, curve, **style)
[4]:
# Root folder of the benchmark outputs; each estimator has its own sub-folder.
directory = "../benchmarks/results/"
# (result sub-folder, legend label) for every estimator, in plotting order.
# Keeping them paired prevents the two derived lists from drifting apart.
_estimators = [
    ("dt/", "decision tree"),
    ("rf/", "random forest"),
    ("sdt/", "stream decision tree"),
    ("sdf/", "stream decision forest"),
    ("ht/", "hoeffding tree"),
    ("mf/", "mondrian forest"),
]
prefixes = [folder for folder, _ in _estimators]
legends = [label for _, label in _estimators]
datasets = ["splice", "pendigits", "cifar10"]
# Number of 100-sample training-size steps per dataset, aligned with
# `datasets`; the plotting cell uses n_train = 100, 200, ..., ranges[j] * 100.
ranges = [23, 74, 500]
Plot
[5]:
# Average the first `n_runs` repetitions of every estimator/dataset result.
# When `concat` is set, the per-step training times of the batch estimators
# (decision tree, random forest) are accumulated into a running total —
# presumably because they are refit from scratch at every training-set size;
# TODO confirm against the benchmark scripts.
concat = True
n_runs = 10  # number of repetitions averaged per estimator/dataset
batch_prefixes = ("dt/", "rf/")
acc_ls = []
time_ls = []
for dataset in datasets:
    acc_l = []
    time_l = []
    for prefix in prefixes:
        stem = directory + prefix + dataset
        acc = np.mean(load_result(stem + "_acc.txt")[:n_runs], axis=0)
        acc_l.append(acc)
        time = np.mean(load_result(stem + "_train_t.txt")[:n_runs], axis=0)
        if concat and prefix in batch_prefixes:
            # Entry k becomes the total time spent on steps 0..k. Unlike the
            # old manual loop, this does not depend on ranges[] matching the
            # length of the stored curve.
            time = np.cumsum(time)
        time_l.append(time)
    acc_ls.append(acc_l)
    time_ls.append(time_l)
# The first len(datasets) entries are accuracy curves and the remaining ones
# are wall-time curves, in dataset order (the plotting cell relies on this).
ls = acc_ls + time_ls
[6]:
# 2 x 3 grid: accuracy (top row) and training wall time (bottom row), one
# column per dataset. Curves come from `ls` and are drawn by `plot_acc`.
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(17, 11), constrained_layout=True)
# Shared x-axis label placed below the bottom row.
fig.text(0.53, -0.05, "Number of Train Samples", ha="center")
xtitles = ["Splice", "Pendigits", "CIFAR-10"]
ytitles = ["Accuracy", "Wall Time (s)"]
ylimits = [[0, 1], [1e-4, 1e5]]
yticks = [[0, 0.5, 1], [1e-4, 1e-1, 1e2, 1e5]]
for i, row in enumerate(ax):
    for j, col in enumerate(row):
        col.set_xscale("log")
        col.set_ylim(ylimits[i])
        # Training-set sizes 100, 200, ..., ranges[j] * 100.
        n_train = range(100, (ranges[j] + 1) * 100, 100)
        if i == 0:
            # Accuracy row: dataset title on top, x ticks hidden.
            col.set_xticks([])
            col.set_title(xtitles[j])
            plot_acc(col, ls, j, n_train)
        else:
            # Wall-time row: log-log axes with a tick at every power of ten
            # from 1e2 up to the first one >= the largest training size
            # (gives 1e2..1e4 for Splice/Pendigits and 1e2..1e5 for CIFAR-10).
            # NOTE(review): Axis.set_ticks expands the view limits to include
            # every tick, so these panels may span a wider x-range than the
            # accuracy panel above them — confirm this is intended.
            top_exp = int(np.ceil(np.log10(ranges[j] * 100)))
            col.set_xticks([10.0**k for k in range(2, top_exp + 1)])
            col.set_yscale("log")
            # Wall-time curves are stored after the accuracy curves in `ls`.
            plot_acc(col, ls, j + len(datasets), n_train)
        # Only the left-most column carries y ticks and the row's y label.
        if j == 0:
            col.set_yticks(yticks[i])
            col.set_ylabel(ytitles[i])
        else:
            col.set_yticks([])
# Line up the y labels of the two rows.
fig.align_ylabels(ax)
# Figure-level legend below the panels. Only the first accuracy panel labels
# its curves (see plot_acc), so each estimator appears exactly once.
leg = fig.legend(
    bbox_to_anchor=(0.53, -0.2),
    bbox_transform=fig.transFigure,
    ncol=3,
    loc="lower center",
)
leg.get_frame().set_linewidth(0.0)
# Thicken the legend lines. `Legend.legendHandles` was renamed to
# `legend_handles` in Matplotlib 3.7 and removed in 3.9; support both.
handles = getattr(leg, "legend_handles", None)
if handles is None:
    handles = leg.legendHandles
for legobj in handles:
    legobj.set_linewidth(5.0)
fig.savefig("../paper/visual.pdf", transparent=True, bbox_inches="tight")
