From d9009970ad8dde0fcaf3afe9452e8e7630ea18d8 Mon Sep 17 00:00:00 2001 From: Andreas Tsouchlos Date: Tue, 15 Nov 2022 16:09:57 +0100 Subject: [PATCH] Implemented proper titles and line labels in visualization.show_BER_curves() --- sw/plot_results.py | 46 +++++++++++++++++++++++++++++++++++++ sw/utility/visualization.py | 24 +++++++++++-------- 2 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 sw/plot_results.py diff --git a/sw/plot_results.py b/sw/plot_results.py new file mode 100644 index 0000000..838789c --- /dev/null +++ b/sw/plot_results.py @@ -0,0 +1,46 @@ +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import os +from utility import visualization + + +# TODO: Fix spacing between axes and margins +def plot_results(): + graph_names = {"96.3.965": "n=96, k=48 - 965", + "204.3.486": "n=204, k=102 - 486", + "204.55.187": "n=204, k=102 - 187", + "408.33.844": "n=408, k=204 - 844", + "816.1A4.845": "n=816, k=272 - 843", + "999.111.3.5543": "n=999, k=888 - 5543", + "999.111.3.5565": "n=999, k=888 - 5565", + "PEGReg252x504": "n=504, k=252 - PEGReg"} + + line_labels = {"BER_ML": "ML", + "BER_prox_0_15": "$\gamma = 0.15$", + "BER_prox_0_05": "$\gamma = 0.05$", + "BER_prox_0_01": "$\gamma = 0.01$"} + + # Read data from files + results_dir = "sim_results" + + data = {} + for file in os.listdir(results_dir): + if file.endswith(".csv"): + filename = os.path.splitext(file)[0] + + df = pd.read_csv(os.path.join(results_dir, file)) + df = df.loc[:, ~df.columns.str.contains('^Unnamed')] + data[graph_names[filename]] = df + + # Create and show graphs + + sns.set_theme() + fig = visualization.show_BER_curves("Bit-Error-Rates of proximal decoder for different codes", + data, num_cols=4, line_labels=line_labels) + plt.show() + + +if __name__ == "__main__": + plot_results() diff --git a/sw/utility/visualization.py b/sw/utility/visualization.py index c63e9fa..15ca7c3 100644 --- a/sw/utility/visualization.py +++ b/sw/utility/visualization.py @@ -18,15 +18,18 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int: # TODO: Calculate fig size in relation to the number of rows and columns -# TODO: Set proper line labels -# TODO: Set proper axis titles -# TODO: Should unnamed columns be dropped by this function or by the caller? # TODO: Handle number of graphs not nicely fitting into rows and columns -def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.figure: +def show_BER_curves(title: str, + data: typing.Dict[str, pd.DataFrame], + line_labels: typing.Dict[str, str], + num_cols: int = 3) -> plt.figure: """This function creates a matplotlib figure containing a number of BER curves. - :param data: List of pandas DataFrames containing the data to be plotted. Each dataframe in the list is plotted - in a new graph. Each dataframe is assumed to contain a column named "SNR" which is used as the x-axis + :param title: Title of the figure + :param data: Dictionary where each key corresponds to the name of a new graph and the value is a pandas Dataframe + containing the data to be plotted. Each dataframe is assumed to contain a column named "SNR" which is used + as the x-axis + :param line_labels: Dictionary mapping column names to proper labels :param num_cols: Number of columns in which the graphs should be arranged in the resulting figure :return: Matplotlib figure """ @@ -34,17 +37,18 @@ def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.f num_rows = _get_num_rows(num_graphs, num_cols) fig, axes = plt.subplots(num_rows, num_cols, squeeze=False) - fig.suptitle("Bit-Error-Rates of various decoders for different codes") + fig.suptitle(title) axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array - for axis, df in zip(axes, data): + for axis, name_data_pair in zip(axes, sorted(data.items())): + graph_name, df = name_data_pair column_names = [column for column in df.columns.values.tolist() if not column == "SNR"] for column in column_names: - sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=column) + sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=line_labels[column]) - #axis.set_title(code) + axis.set_title(graph_name) axis.set(yscale="log") axis.set_xlabel("SNR") axis.set_ylabel("BER")