From 9beda2231d1be938e43dd2655772427938b51175 Mon Sep 17 00:00:00 2001 From: Andreas Tsouchlos Date: Mon, 21 Nov 2022 13:28:00 +0100 Subject: [PATCH] Modified visualization code to wread metadata --- sw/plot_results.py | 48 ++++++++++++++++++------------------- sw/utility/visualization.py | 41 ++++++++++++++++--------------- 2 files changed, 45 insertions(+), 44 deletions(-) diff --git a/sw/plot_results.py b/sw/plot_results.py index 87b2486..b27842c 100644 --- a/sw/plot_results.py +++ b/sw/plot_results.py @@ -3,41 +3,39 @@ import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import os -from utility import visualization +from utility import visualization, simulation def plot_results(): - graph_names = {"96.3.965": "n=96, k=48 - 965", - "204.3.486": "n=204, k=102 - 486", - "204.55.187": "n=204, k=102 - 187", - "408.33.844": "n=408, k=204 - 844", - "816.1A4.845": "n=816, k=272 - 843", - "999.111.3.5543": "n=999, k=888 - 5543", - "999.111.3.5565": "n=999, k=888 - 5565", - "PEGReg252x504": "n=504, k=252 - PEGReg"} - line_labels = {"BER_ML": "ML", - "BER_prox_0_15": "$\gamma = 0.15$", - "BER_prox_0_05": "$\gamma = 0.05$", - "BER_prox_0_01": "$\gamma = 0.01$"} + sim_names = [ + "96.3965", + "204.3.486", + "204.55.187", + "408.33.844", + "816.1A4.845", + "999.111.3.5543", + "999.111.3.5565", + "PEGReg252x504" + ] - # Read data from files - results_dir = "sim_results" + deserializer = simulation.SimulationDeSerializer(save_dir="sim_saves", results_dir="sim_results") - data = {} - for file in os.listdir(results_dir): - if file.endswith(".csv"): - filename = os.path.splitext(file)[0] + data = [] + for sim_name in sim_names: + df, metadata = deserializer.read_results(sim_name) + df = df.loc[:, ~df.columns.str.contains('^Unnamed')] - df = pd.read_csv(os.path.join(results_dir, file)) - df = df.loc[:, ~df.columns.str.contains('^Unnamed')] - data[graph_names[filename]] = df + graph_title = sim_name + line_labels = metadata["labels"] - # Create and show graphs + graph_tuple = (graph_title, df, line_labels) + data.append(graph_tuple) sns.set_theme() - fig = visualization.show_BER_curves("Bit-Error-Rates of proximal decoder for different codes", - data, num_cols=4, line_labels=line_labels) + fig = visualization.plot_BERs(title="Bit-Error-Rates of proximal decoder for different codes", + data=data, num_cols=4) + plt.show() diff --git a/sw/utility/visualization.py b/sw/utility/visualization.py index 56d7b93..d855059 100644 --- a/sw/utility/visualization.py +++ b/sw/utility/visualization.py @@ -18,20 +18,22 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int: # TODO: Handle number of graphs not nicely fitting into rows and columns -def show_BER_curves(title: str, - data: typing.Dict[str, pd.DataFrame], - line_labels: typing.Dict[str, str], - num_cols: int = 3) -> plt.figure: - """This function creates a matplotlib figure containing a number of BER curves. +def plot_BERs(title: str, + data: typing.Sequence[typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]], + num_cols: int = 3) -> plt.figure: + """This function creates a matplotlib figure containing a number of plots. + + The plots created are logarithmic and the scaling is adjusted to be sensible for BER plots. :param title: Title of the figure - :param data: Dictionary where each key corresponds to the name of a new graph and the value is a pandas Dataframe - containing the data to be plotted. Each dataframe is assumed to contain a column named "SNR" which is used - as the x-axis - :param line_labels: Dictionary mapping column names to proper labels + :param data: Sequence of tuples. Each tuple corresponds to a new plot and + is of the following form: [graph_title, pd.Dataframe, [line_label_1, line_label2, ...]]. + Each dataframe is assumed to have an "SNR" column that is used as the x axis. :param num_cols: Number of columns in which the graphs should be arranged in the resulting figure :return: Matplotlib figure """ + # Determine layout and create figure + num_graphs = len(data) num_rows = _get_num_rows(num_graphs, num_cols) @@ -47,18 +49,19 @@ def show_BER_curves(title: str, axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array - for axis, name_data_pair in zip(axes, sorted(data.items())): - graph_name, df = name_data_pair + # Populate axes + + for axis, (graph_title, df, labels) in zip(axes, data): column_names = [column for column in df.columns.values.tolist() if not column == "SNR"] - for column in column_names: - sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=line_labels[column]) + for column, label in zip(column_names, labels): + sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label) - axis.set_title(graph_name) - axis.set(yscale="log") - axis.set_xlabel("SNR") - axis.set_ylabel("BER") - axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0]) - axis.legend() + axis.set_title(graph_title) + axis.set(yscale="log") + axis.set_xlabel("SNR") + axis.set_ylabel("BER") + axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0]) + axis.legend() return fig