# Source code for ai4materials.utils.utils_plotting

# coding=utf-8
# Copyright 2016-2018 Angelo Ziletti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

__author__ = "Angelo Ziletti"
__copyright__ = "Copyright 2018, Angelo Ziletti"
__maintainer__ = "Angelo Ziletti"
__email__ = "ziletti@fhi-berlin.mpg.de"
__date__ = "23/09/18"

import itertools
import logging
import os

import matplotlib

# Force matplotlib to not use any Xwindows backend.
# NOTE: must run before the first import of matplotlib.pyplot.
matplotlib.use('Agg')

import matplotlib.pyplot as plt

plt.switch_backend('agg')

import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import cm, colors
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.font_manager import FontProperties
from matplotlib.pyplot import figure, show, axes, sci
from numpy import amin, amax, ravel

# NOTE(review): the original `os.system("export DISPLAY=:0")` ran in a child
# shell and could not affect this process' environment, so it was a no-op and
# has been dropped; the Agg backend needs no display anyway. The duplicate
# `import os` and `import matplotlib.cm as cm` (rebound by
# `from matplotlib import cm` above) were removed as well.

# tf.set_random_seed(0)  # equivalent call for tf 1.x
tf.random.set_seed(0)

logger = logging.getLogger('ai4materials')


def insert_newlines(string, every=64):
    """Break *string* into chunks of at most *every* characters, joined by newlines."""
    chunks = (string[start:start + every] for start in range(0, len(string), every))
    return '\n'.join(chunks)
def plot_sph_harmonics():
    """Render the real part of spherical harmonics on a sphere with mayavi.

    Adapted from:
    http://docs.enthought.com/mayavi/mayavi/auto/example_spherical_harmonics.html
    Opens an interactive mayavi window; returns nothing.
    """
    from mayavi import mlab
    import numpy as np
    from scipy.special import sph_harm

    # Build a sphere of radius 0.3 on a 101x101 (phi, theta) grid.
    radius = 0.3
    pi = np.pi
    cos = np.cos
    sin = np.sin
    phi, theta = np.mgrid[0:pi:101j, 0:2 * pi:101j]

    x = radius * sin(phi) * cos(theta)
    y = radius * sin(phi) * sin(theta)
    z = radius * cos(phi)

    mlab.figure(1, bgcolor=(1, 1, 1), fgcolor=(0, 0, 0), size=(400, 300))
    mlab.clf()

    # Draw Y_n^m on the sphere surface for degrees n = 1..5, orders m = 0..n-1;
    # each (m, n) pair is offset in the x/y plane to form a grid of spheres.
    for degree in range(1, 6):
        for order in range(degree):
            harmonic = sph_harm(order, degree, theta, phi).real
            mlab.mesh(x - order, y - degree, z, scalars=harmonic, colormap='jet')
            # Second row: deform the sphere radially by the (rescaled) harmonic.
            harmonic[harmonic < 0] *= 0.97
            harmonic /= harmonic.max()
            mlab.mesh(harmonic * x - order, harmonic * y - degree, harmonic * z + 1.3,
                      scalars=harmonic, colormap='Spectral')

    mlab.view(90, 70, 6.2, (-1.3, -2.9, 0.25))
    mlab.show()
def plot_save_cnn_results(filename, accuracy=True, cross_entropy_loss=True, show_plot=False):
    """Plot and save results of a convolutional neural network calculation
    from the .csv file written by Keras CSVLogger.

    Saves '<filename stem>_accuracy.png' and/or '<filename stem>_cross_entropy_loss.png'
    next to the input csv file.

    .. codeauthor:: Angelo Ziletti <angelo.ziletti@gmail.com>
    """
    history = pd.read_csv(filename)
    plt.style.use('fivethirtyeight')

    # CSVLogger epochs are 0-based; shift to 1-based for plotting.
    epochs = history.epoch.values + 1
    base_name = filename.rsplit('.', 1)[0]

    if accuracy:
        acc_train = history.acc.values * 100.0
        acc_val = history.val_acc.values * 100.0
        fig_acc = make_plot_accuracy(epochs, acc_train, acc_val)
        # save png file (same name as csv file, but with png extension)
        fig_acc.savefig(base_name + '_accuracy.png', format="png")
        if show_plot:
            fig_acc.show()

    if cross_entropy_loss:
        loss_train = history.loss.values * 100.0
        loss_val = history.val_loss.values * 100.0
        fig_loss = make_plot_cross_entropy_loss(epochs, loss_train, loss_val)
        # save png file (same name as csv file, but with png extension)
        fig_loss.savefig(base_name + '_cross_entropy_loss.png', format="png")
        if show_plot:
            fig_loss.show()
def make_plot_accuracy(step, train_data, val_data):
    """Plot training/validation accuracy vs. epoch on two stacked panels.

    The top panel shows the full 0-105% range; the bottom panel zooms
    into the 95-100.5% range.

    Parameters:

    step: np.ndarray
        Epoch numbers (x-axis values).

    train_data: np.ndarray
        Training accuracy in percent; may contain NaNs for missing epochs.

    val_data: np.ndarray
        Validation accuracy in percent; may contain NaNs for missing epochs.

    Returns: the ``matplotlib.pyplot`` module, so the caller can ``savefig``/``show``.
    """
    # Mask NaNs so a continuous line is drawn across missing values.
    train_data_mask = np.isfinite(train_data)
    val_data_mask = np.isfinite(val_data)

    f, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=False, figsize=(13, 10))
    plt.suptitle("Convolutional neural network: model accuracy", fontname='Ubuntu', fontsize=24,
                 fontstyle='italic', fontweight='bold')
    plt.tight_layout(pad=4.0, w_pad=2.0, h_pad=3.0)
    plt.grid(True)

    # Top panel: full accuracy range.
    ax1.set_xlim([-np.amax(step) * 0.01 + 1.0, np.amax(step) * 1.01])
    ax1.set_ylim([0, 105.0])
    ax1.xaxis.set_ticks(np.arange(min(step), max(step) + 1, 1))
    ax1.plot(step[train_data_mask], train_data[train_data_mask], 'ro-', label='Training accuracy')
    ax1.plot(step[val_data_mask], val_data[val_data_mask], 'go-', label='Validation accuracy')
    ax1.set_xlabel('Epoch number')
    ax1.set_ylabel('Accuracy [%]')
    # FIX: Axes.set_axis_bgcolor was removed in matplotlib 2.2;
    # set_facecolor is the supported replacement.
    ax1.set_facecolor((224 / 255, 224 / 255, 224 / 255))
    legend = ax1.legend(loc='lower right', borderaxespad=0., frameon=1)
    for text in legend.get_texts():
        plt.setp(text, color=(224 / 255, 224 / 255, 224 / 255))
    frame = legend.get_frame()
    frame.set_facecolor((32 / 255, 32 / 255, 32 / 255))
    frame.set_edgecolor((32 / 255, 32 / 255, 32 / 255))

    # Bottom panel: zoom on the high-accuracy region.
    ax2.set_xlim([-np.amax(step) * 0.01 + 1.0, np.amax(step) * 1.01])
    ax2.set_ylim([95, 100.5])
    ax2.xaxis.set_ticks(np.arange(min(step), max(step) + 1, 1))
    ax2.plot(step[train_data_mask], train_data[train_data_mask], 'ro-', label='Training accuracy')
    ax2.plot(step[val_data_mask], val_data[val_data_mask], 'go-', label='Validation accuracy')
    ax2.set_xlabel('Epoch number')
    ax2.set_ylabel('Accuracy [%]')
    ax2.set_facecolor((224 / 255, 224 / 255, 224 / 255))
    legend = ax2.legend(loc='lower right', borderaxespad=0., frameon=1)
    for text in legend.get_texts():
        plt.setp(text, color=(224 / 255, 224 / 255, 224 / 255))
    frame = legend.get_frame()
    frame.set_facecolor((32 / 255, 32 / 255, 32 / 255))
    frame.set_edgecolor((32 / 255, 32 / 255, 32 / 255))

    return plt
def make_plot_cross_entropy_loss(step, train_data, val_data, title=None):
    """Plot training/validation cross-entropy loss vs. epoch on two stacked panels.

    The top panel shows the full loss range; the bottom panel zooms in near the
    minimum of the loss.

    Parameters:

    step: np.ndarray
        Epoch numbers (x-axis values).

    train_data: np.ndarray
        Training loss; may contain NaNs for missing epochs.

    val_data: np.ndarray
        Validation loss; may contain NaNs for missing epochs.

    title: string, optional
        Figure title. Defaults to "Convolutional neural network: cross-entropy loss".
        (FIX: this parameter existed but was previously ignored.)

    Returns: the ``matplotlib.pyplot`` module, so the caller can ``savefig``/``show``.
    """
    # Mask NaNs so a continuous line is drawn across missing values.
    train_data_mask = np.isfinite(train_data)
    val_data_mask = np.isfinite(val_data)

    if title is None:
        title = "Convolutional neural network: cross-entropy loss"

    f, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=False, figsize=(13, 10))
    plt.suptitle(title, fontname='Ubuntu', fontsize=24, fontstyle='italic', fontweight='bold')
    plt.tight_layout(pad=4.0, w_pad=2.0, h_pad=3.0)
    plt.grid(True)

    # Top panel: full loss range.
    ax1.set_xlim([-np.amax(step) * 0.01 + 1.0, np.amax(step) * 1.01])
    min_value = min(np.nanmin(train_data), np.nanmin(val_data))
    max_value = max(np.nanmax(train_data), np.nanmax(val_data))
    ax1.set_ylim([-max_value * 0.05, max_value * 1.03])
    ax1.xaxis.set_ticks(np.arange(min(step), max(step) + 1, 1))
    # FIX: legend entries previously said "accuracy" on the loss plot.
    ax1.plot(step[train_data_mask], train_data[train_data_mask], 'ro-', label='Training loss')
    ax1.plot(step[val_data_mask], val_data[val_data_mask], 'go-', label='Validation loss')
    ax1.set_xlabel('Epoch number')
    ax1.set_ylabel('Cross entropy loss')
    # FIX: Axes.set_axis_bgcolor was removed in matplotlib 2.2;
    # set_facecolor is the supported replacement.
    ax1.set_facecolor((224 / 255, 224 / 255, 224 / 255))
    legend = ax1.legend(loc='upper right', borderaxespad=0., frameon=1)
    for text in legend.get_texts():
        plt.setp(text, color=(224 / 255, 224 / 255, 224 / 255))
    frame = legend.get_frame()
    frame.set_facecolor((32 / 255, 32 / 255, 32 / 255))
    frame.set_edgecolor((32 / 255, 32 / 255, 32 / 255))

    # Bottom panel: zoom near the loss minimum.
    ax2.set_xlim([-np.amax(step) * 0.01 + 1.0, np.amax(step) * 1.01])
    ax2.set_ylim([-1.0 + min_value * 0.97, max_value * 0.20])
    ax2.xaxis.set_ticks(np.arange(min(step), max(step) + 1, 1))
    ax2.plot(step[train_data_mask], train_data[train_data_mask], 'ro-', label='Training loss')
    ax2.plot(step[val_data_mask], val_data[val_data_mask], 'go-', label='Validation loss')
    ax2.set_xlabel('Epoch number')
    ax2.set_ylabel('Cross entropy loss')
    ax2.set_facecolor((224 / 255, 224 / 255, 224 / 255))
    legend = ax2.legend(loc='upper right', borderaxespad=0., frameon=1)
    for text in legend.get_texts():
        plt.setp(text, color=(224 / 255, 224 / 255, 224 / 255))
    frame = legend.get_frame()
    frame.set_facecolor((32 / 255, 32 / 255, 32 / 255))
    frame.set_edgecolor((32 / 255, 32 / 255, 32 / 255))

    return plt
def aggregate_struct_trans_data(filename, nb_rows_to_cut=0, nb_samples=None, nb_order_param_steps=None,
                                min_order_param=0.0, max_order_param=None, prob_idxs=None, with_uncertainty=True,
                                uncertainty_types=('variation_ratio', 'predictive_entropy', 'mutual_information')):
    """Aggregate structural transition data in order to plot it later.

    Starting from the results_file of the run_cnn_model function, aggregate the
    data by a given order parameter and the probabilities of each class. This is
    used to prepare the data for the structural transition plots, as shown in
    Fig. 4, Ziletti et al., Nature Communications 9, 2775 (2018).

    Parameters:

    filename: string
        Full path to the results_file created by the run_cnn_model function.
        This is a csv file.

    nb_rows_to_cut: int, optional (default = 0)
        Number of leading rows of the csv file to discard before aggregating.

    nb_samples: int
        Number of samples present in results_file for each order parameter step.

    nb_order_param_steps: int
        Number of order parameter steps. For example, if we are interpolating
        between structure_1 and structure_2 with 10 steps, nb_order_param_steps=10.

    min_order_param: float, optional (default = 0.0)
        Minimum value the order parameter takes in the dataset.

    max_order_param: float
        Maximum value the order parameter takes in the dataset. Used (together
        with nb_order_param_steps) to create the linear space later used by the
        plotting function.

    prob_idxs: list of int
        Classes for which the probabilities are extracted from the results_file.
        prob_idxs=[0, 3] extracts only prob_predictions_0 and prob_predictions_3.

    with_uncertainty: bool, optional (default = True)
        If True, also aggregate the uncertainty columns given by
        `uncertainty_types` and merge them into the result.

    Returns: pandas.DataFrame
        A dataframe with column 'a_to_b_index_' (the order parameter value) and,
        for each i in prob_idxs, columns 'prob_predictions_i_mean' and
        'prob_predictions_i_std' (mean/std of the classification probability at
        that order parameter value); optionally analogous uncertainty columns.

    .. codeauthor:: Angelo Ziletti <angelo.ziletti@gmail.com>
    """
    df = pd.read_csv(filename)

    # Throw away the first 'nb_rows_to_cut' rows because they come from
    # descriptor_all_classes_8_samples.tar.gz; it is a workaround to have the
    # neural network predict even if not all classes are present in the dataset.
    df = df[nb_rows_to_cut:]

    # One order-parameter value repeated for each of the nb_samples rows per step.
    steps, step = np.linspace(min_order_param, max_order_param, nb_order_param_steps, retstep=True)
    a_to_b_index = np.repeat(steps, nb_samples)
    df['a_to_b_index'] = a_to_b_index

    prob_pred_agg = {}
    for prob_idx in prob_idxs:
        prob_prediction = 'prob_predictions_' + str(prob_idx)
        prob_pred_agg.update({prob_prediction: ['mean', 'std']})

    df_results_prob = df.groupby(['a_to_b_index'], as_index=False).agg(prob_pred_agg)
    # Flatten the hierarchical column index produced by agg.
    # NB: you cannot just rename the columns: the values are ordered by
    # increasing mean, so the column name --> value mapping would not be conserved.
    df_results_prob.columns = ['_'.join(col).strip() for col in df_results_prob.columns.values]
    # FIX: DataFrame.reindex returns a new frame; the original discarded the
    # result, so the columns were never actually sorted.
    df_results_prob = df_results_prob.reindex(columns=sorted(df_results_prob.columns))

    if with_uncertainty:
        uncertainty_pred_agg = {}
        for uncertainty_type in uncertainty_types:
            uncertainty_pred = 'uncertainty_' + str(uncertainty_type)
            uncertainty_pred_agg.update({uncertainty_pred: ['mean', 'std']})

        df_results_uncertainty = df.groupby(['a_to_b_index'], as_index=False).agg(uncertainty_pred_agg)
        df_results_uncertainty.columns = ['_'.join(col).strip() for col in df_results_uncertainty.columns.values]
        df_results_uncertainty = df_results_uncertainty.reindex(columns=sorted(df_results_uncertainty.columns))

        # Merge the probability prediction results with the uncertainty results.
        df_results = pd.merge(df_results_prob, df_results_uncertainty, on='a_to_b_index_')
    else:
        df_results = df_results_prob

    return df_results
def make_crossover_plot(df_results, filename, filename_suffix, title, labels, nb_order_param_steps,
                        plot_type='probability', prob_idxs=None, uncertainty_type='mutual_information',
                        linewidth=1.0, markersize=1.0, max_nb_ticks=None, palette=None, show_plot=False,
                        style='publication', x_label="Order parameter"):
    """Plot classification probability (or uncertainty) distributions vs. an order parameter.

    Starting from an aggregated data pandas dataframe (see
    `aggregate_struct_trans_data`), this produces a plot along the lines of
    Fig. 4, Ziletti et al., Nature Communications 9, 2775 (2018).

    Parameters:

    df_results: pandas dataframe
        Dataframe returned by the `aggregate_struct_trans_data` function.

    filename: string
        Full path to the results_file created by the run_cnn_model function.
        Only used to name the generated plot appropriately.

    filename_suffix: string
        Suffix for the plot filename; determines the output format
        ('.png' or '.svg').

    title: string
        Title of the plot.

    labels: list of strings
        Legend labels, indexed by class.

    nb_order_param_steps: int
        Number of order parameter steps; must match the value used in
        `aggregate_struct_trans_data`.

    plot_type: str (options: 'probability', 'uncertainty')
        Plot either probabilities of classification or uncertainty.

    prob_idxs: list of int
        Classes whose probabilities are plotted; must correspond to (or be a
        subset of) the prob_idxs used in `aggregate_struct_trans_data`.

    uncertainty_type: str (options: 'mutual_information', 'predictive_entropy')
        Uncertainty estimate to plot; used only if plot_type='uncertainty'.

    max_nb_ticks: int, optional
        Upper bound on the number of x-axis ticks (capped at nb_order_param_steps).

    show_plot: bool, optional, default: False
        If True, open the generated plot.

    style: string, optional, {'publication'}
        'publication' loads the default matplotlib style (white background);
        anything else uses the 'fivethirtyeight' style (dark background).

    x_label: string, optional, default: "Order parameter"
        Label for the x-axis (the order parameter axis).

    .. codeauthor:: Angelo Ziletti <angelo.ziletti@gmail.com>
    """
    if style == 'publication':
        plt.style.use('default')
    else:
        plt.style.use('fivethirtyeight')

    # colors from https://matplotlib.org/examples/color/named_colors.html
    if palette is None:
        palette = ['yellow', 'red', 'blue', 'green', 'purple', 'orange', 'black']

    a_to_b_param = df_results.a_to_b_index_.values

    colors_plot = []
    labels_sel = []
    y_label_name_mean = []
    y_label_name_std = []

    if plot_type == 'probability':
        for prob_idx in prob_idxs:
            y_label_name_mean.append('prob_predictions_' + str(prob_idx) + '_mean')
            y_label_name_std.append('prob_predictions_' + str(prob_idx) + '_std')
            colors_plot.append(palette[prob_idx])
            labels_sel.append(labels[prob_idx])
    elif plot_type == 'uncertainty':
        y_label_name_mean.append('uncertainty_' + str(uncertainty_type) + '_mean')
        y_label_name_std.append('uncertainty_' + str(uncertainty_type) + '_std')
        colors_plot.append(palette[0])
        labels_sel.append(labels[0])
    else:
        raise Exception("Please specify a valid plot_type. Possible values are: 'probability', 'uncertainty'.")

    y_value_mean = []
    y_value_std = []
    if plot_type == 'probability':
        # a is 1st prob_idx, b is 2nd (order matters for the plot)
        for prob_idx in range(len(prob_idxs)):
            y_value_mean.append(df_results[y_label_name_mean[prob_idx]].values)
            y_value_std.append(df_results[y_label_name_std[prob_idx]].values)
    elif plot_type == 'uncertainty':
        y_value_mean.append(df_results[y_label_name_mean].values)
        y_value_std.append(df_results[y_label_name_std].values)

    # Cap the number of x-axis ticks at the number of order parameter steps.
    if max_nb_ticks is not None:
        max_nb_ticks = min(max_nb_ticks, nb_order_param_steps)
    else:
        max_nb_ticks = nb_order_param_steps

    steps, step = np.linspace(np.amin(a_to_b_param), np.amax(a_to_b_param), max_nb_ticks, retstep=True)

    # The sigma/std_scaling upper and lower analytic population bounds.
    std_scaling = 1.0
    lower_bound = []
    upper_bound = []
    if plot_type == 'probability':
        for prob_idx in range(len(prob_idxs)):
            lower_bound.append(y_value_mean[prob_idx] - y_value_std[prob_idx] / std_scaling)
            upper_bound.append(y_value_mean[prob_idx] + y_value_std[prob_idx] / std_scaling)
    elif plot_type == 'uncertainty':
        lower_bound.append(y_value_mean[0] - y_value_std[0] / std_scaling)
        upper_bound.append(y_value_mean[0] + y_value_std[0] / std_scaling)

    fig, ax = plt.subplots(1)
    plt.suptitle(title, fontname='Ubuntu', fontsize=15, fontstyle='italic', fontweight='bold')
    plt.tight_layout(pad=5.0, w_pad=2.0, h_pad=1.0)
    plt.grid(True, color='gray', linestyle='--', linewidth=0.5)

    ax.set_xlim([np.amin(a_to_b_param), np.amax(a_to_b_param)])
    if plot_type == 'probability':
        ax.set_ylim([-0.1, 1.1])
    ax.xaxis.set_ticks(steps)

    for tick in ax.xaxis.get_major_ticks():
        # FIX: Tick.label was removed in matplotlib 3.8; Tick.label1 is the
        # long-standing equivalent attribute.
        tick.label1.set_fontsize(2)
        tick.label1.set_rotation('vertical')

    if plot_type == 'probability':
        for prob_idx in range(len(prob_idxs)):
            ax.plot(a_to_b_param, y_value_mean[prob_idx], marker='o', linestyle='-',
                    color=colors_plot[prob_idx], label=labels_sel[prob_idx], linewidth=linewidth,
                    markeredgecolor=colors_plot[prob_idx], markersize=markersize)
            ax.fill_between(a_to_b_param, lower_bound[prob_idx], upper_bound[prob_idx],
                            facecolor=colors_plot[prob_idx], alpha=0.2,
                            edgecolor=colors_plot[prob_idx], linewidth=0.0)
    elif plot_type == 'uncertainty':
        ax.plot(a_to_b_param, y_value_mean[0], marker='o', linestyle='-', color=colors_plot[0],
                label=labels_sel[0], linewidth=linewidth, markeredgecolor=colors_plot[0],
                markersize=markersize)
        ax.fill_between(a_to_b_param, np.array(lower_bound).reshape(-1), np.array(upper_bound).reshape(-1),
                        facecolor=colors_plot[0], alpha=0.2, edgecolor=colors_plot[0], linewidth=0.0)

    ax.set_xlabel(x_label, fontsize=15)
    if plot_type == 'probability':
        ax.set_ylabel("Classification probability", fontsize=15)
    elif plot_type == 'uncertainty':
        if uncertainty_type == 'mutual_information':
            ax.set_ylabel("Mutual information", fontsize=15)
        elif uncertainty_type == 'predictive_entropy':
            ax.set_ylabel("Predictive entropy", fontsize=15)
        else:
            ax.set_ylabel("Label", fontsize=15)
    ax.tick_params(labelsize=15)

    legend = ax.legend(loc='center left', fontsize=10, bbox_to_anchor=(0.1, 0.5),
                       borderaxespad=1.0, frameon=1)
    if style == 'publication':
        for text in legend.get_texts():
            plt.setp(text, color=(0 / 255, 0 / 255, 0 / 255))
    else:
        # Dark ('fivethirtyeight') styling: light text on dark legend frame.
        for text in legend.get_texts():
            plt.setp(text, color=(224 / 255, 224 / 255, 224 / 255))
        # FIX: Axes.set_axis_bgcolor was removed in matplotlib 2.2;
        # set_facecolor is the supported replacement.
        ax.set_facecolor((224 / 255, 224 / 255, 224 / 255))
        frame = legend.get_frame()
        frame.set_facecolor((32 / 255, 32 / 255, 32 / 255))
        frame.set_edgecolor((32 / 255, 32 / 255, 32 / 255))

    if filename_suffix == ".png":
        plt.savefig(filename.rsplit('.', 1)[0] + '_' + plot_type + filename_suffix, format="png")
    elif filename_suffix == ".svg":
        plt.savefig(filename.rsplit('.', 1)[0] + '_' + plot_type + filename_suffix, format="svg")
    else:
        raise Exception("Filename suffix {0} is not a valid file format.".format(filename_suffix))

    if show_plot:
        plt.show()
def show_images(images, filename_png, cols=1, titles=None):
    """Display a list of images in a single figure with matplotlib.

    Taken from https://stackoverflow.com/questions/11159436/multiple-figures-in-a-single-window

    Parameters:

    images: list of np.arrays
        Images to be plotted. It must be compatible with plt.imshow.

    filename_png: string
        Full path of the png file where the figure is saved.

    cols: int, optional, (default = 1)
        Number of columns in figure (number of rows is set to
        np.ceil(n_images/float(cols))).

    titles: list of strings
        List of titles corresponding to each image.
    """
    plt.clf()
    assert (titles is None) or (len(images) == len(titles))
    n_images = len(images)
    if titles is None:
        titles = ['Image (%d)' % i for i in range(1, n_images + 1)]

    # FIX: add_subplot takes (nrows, ncols, index); the original passed
    # (cols, rows, ...), swapping the grid dimensions, and matplotlib >= 3.3
    # rejects the float returned by np.ceil — hence the explicit int().
    n_rows = int(np.ceil(n_images / float(cols)))

    fig = plt.figure()
    for n, (image, title) in enumerate(zip(images, titles)):
        a = fig.add_subplot(n_rows, cols, n + 1)
        # Use a common color scale across all images.
        plt.imshow(image, interpolation='spline16', cmap='viridis',
                   vmin=np.amin(images), vmax=np.amax(images))
        a.set_title(title)
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.savefig(filename_png, dpi=10, format='png')
def make_multiple_image_plot(data, title="Figure 1", cmap=cm.hot, n_rows=None, n_cols=None, vmin=None,
                             vmax=None, filename=None, save=False):
    """Plot a stack of images as an n_rows x n_cols grid, one figure per channel.

    Parameters:

    data: np.ndarray
        Array indexed as data[image_index, channel, y, x]. Must have either 1 or
        3 channels; for 3 channels a pure red/green/blue colormap is used per channel.

    title: string
        Figure title.

    cmap: matplotlib colormap
        Colormap used when the data has a single channel.

    n_rows, n_cols: int
        Grid dimensions for laying out the images.

    vmin, vmax: float, optional
        Color-scale limits. vmin is clamped to be >= 0 and tightened to the data
        minimum seen so far; if None, it is seeded from the first image.

    filename: string
        Base filename; '_ch<idx>' is appended per channel before the extension.

    save: bool
        If True, save one png per channel.

    Returns: list of the per-channel filenames.
    """
    fig = plt.figure()
    plt.suptitle(title, fontname='Ubuntu', fontsize=15, fontstyle='italic', fontweight='bold')
    plt.style.use('fivethirtyeight')

    margin = 0.08
    w = (1.0 - margin * 2) / n_cols
    h = (1.0 - margin * 2) / n_rows

    nb_channels = data.shape[1]
    cmaps = []
    if nb_channels == 3:
        # Define colormaps going from 0 to full red, green and blue.
        colors_for_maps = ["red", "green", "blue"]
        for color_for_maps in colors_for_maps:
            cmaps.append(rgb_colormaps(color_for_maps))
    elif nb_channels == 1:
        cmaps.append(cmap)
    else:
        raise Exception("Unexpected number of color channels: {}".format(nb_channels))

    filenames_ch = []
    for idx_ch in range(nb_channels):
        images = []
        idx_filter = 0
        for i in range(n_cols):
            for j in range(n_rows):
                if idx_filter < data.shape[0]:
                    # https://python4astronomers.github.io/plotting/advanced.html
                    # axes placed top-first
                    pos = [margin + i * 1.0 * w, (1.0 - j * 1.0 * h - h - margin), w, h]
                    a = fig.add_axes(pos)
                    data_filter = data[idx_filter, idx_ch, :, :]
                    dd = ravel(data_filter)
                    # Manually track the min of all colors for the color scale.
                    # FIX: min(None, x) raises TypeError on Python 3 when vmin is
                    # left at its default; seed vmin from the first data seen.
                    vmin = amin(dd) if vmin is None else min(vmin, amin(dd))
                    # make sure vmin is positive or zero
                    vmin = max(0.0, vmin)
                    # stretches the images to the desired width
                    images.append(a.imshow(data_filter, cmap=cmaps[idx_ch], vmin=vmin, vmax=vmax))
                    # do not show axis
                    plt.axis('off')
                    idx_filter += 1

        # Split filename to insert the channel index before the extension.
        filename_no_ext, file_extension = os.path.splitext(filename)
        filename_ch = filename_no_ext + "_ch" + str(idx_ch) + file_extension
        filenames_ch.append(filename_ch)
        if save:
            logger.info("Saving multiple image plot to file.")
            logger.debug("Filename: {0}".format(filename))
            plt.savefig(filename_ch, dpi=600, format="png")
            plt.clf()

    return filenames_ch
def rgb_colormaps(color):
    """Obtain colormaps for RGB.

    Builds a linear colormap ramping the requested channel from 0 to full
    intensity while the other two channels stay at 0.

    For a general overview:
    https://matplotlib.org/examples/pylab_examples/custom_cmap.html
    """
    allowed = ('red', 'green', 'blue')
    if color not in allowed:
        raise ValueError("Wrong color specified. Allowed colors are 'red', 'green', 'blue'.")

    # Segment data: the selected channel ramps 0 -> 1; the others stay flat at 0.
    ramp_up = ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0))
    flat_zero = ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0))
    cdict = {channel: (ramp_up if channel == color else flat_zero) for channel in allowed}

    return LinearSegmentedColormap('BlueRed2', cdict)
def plot_confusion_matrix(conf_matrix, classes, conf_matrix_file, normalize=False, title='Confusion matrix',
                          title_true_label='True label', title_pred_label='Predicted label', cmap='Blues'):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        # Normalize each row by its total so entries are per-true-class fractions.
        row_totals = conf_matrix.sum(axis=1)[:, np.newaxis]
        conf_matrix = conf_matrix.astype('float') / row_totals
        logger.debug("Normalized confusion matrix")
    else:
        logger.debug('Confusion matrix, without normalization')

    fig = plt.figure()
    plt.imshow(conf_matrix, interpolation='none', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    tick_positions = np.arange(len(classes))
    plt.xticks(tick_positions, classes, rotation=45)
    plt.yticks(tick_positions, classes)

    # Annotate every cell; use white text on dark cells for contrast.
    cell_fmt = '.2f' if normalize else 'd'
    threshold = conf_matrix.max() / 2.
    for row, col in itertools.product(range(conf_matrix.shape[0]), range(conf_matrix.shape[1])):
        text_color = "white" if conf_matrix[row, col] > threshold else "black"
        plt.text(col, row, format(conf_matrix[row, col], cell_fmt),
                 horizontalalignment="center", color=text_color)

    plt.tight_layout()
    # add this otherwise the x-axis gets cut
    plt.gcf().subplots_adjust(bottom=0.25)
    plt.ylabel(title_true_label)
    plt.xlabel(title_pred_label)
    plt.savefig(conf_matrix_file, dpi=100, format="png")
    plt.clf()