Modified visualization code to wread metadata

This commit is contained in:
Andreas Tsouchlos 2022-11-21 13:28:00 +01:00
parent d0ed9ffbaa
commit 9beda2231d
2 changed files with 45 additions and 44 deletions

View File

@ -3,41 +3,39 @@ import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
import pandas as pd import pandas as pd
import os import os
from utility import visualization from utility import visualization, simulation
def plot_results(): def plot_results():
graph_names = {"96.3.965": "n=96, k=48 - 965",
"204.3.486": "n=204, k=102 - 486",
"204.55.187": "n=204, k=102 - 187",
"408.33.844": "n=408, k=204 - 844",
"816.1A4.845": "n=816, k=272 - 843",
"999.111.3.5543": "n=999, k=888 - 5543",
"999.111.3.5565": "n=999, k=888 - 5565",
"PEGReg252x504": "n=504, k=252 - PEGReg"}
line_labels = {"BER_ML": "ML", sim_names = [
"BER_prox_0_15": "$\gamma = 0.15$", "96.3965",
"BER_prox_0_05": "$\gamma = 0.05$", "204.3.486",
"BER_prox_0_01": "$\gamma = 0.01$"} "204.55.187",
"408.33.844",
"816.1A4.845",
"999.111.3.5543",
"999.111.3.5565",
"PEGReg252x504"
]
# Read data from files deserializer = simulation.SimulationDeSerializer(save_dir="sim_saves", results_dir="sim_results")
results_dir = "sim_results"
data = {} data = []
for file in os.listdir(results_dir): for sim_name in sim_names:
if file.endswith(".csv"): df, metadata = deserializer.read_results(sim_name)
filename = os.path.splitext(file)[0] df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
df = pd.read_csv(os.path.join(results_dir, file)) graph_title = sim_name
df = df.loc[:, ~df.columns.str.contains('^Unnamed')] line_labels = metadata["labels"]
data[graph_names[filename]] = df
# Create and show graphs graph_tuple = (graph_title, df, line_labels)
data.append(graph_tuple)
sns.set_theme() sns.set_theme()
fig = visualization.show_BER_curves("Bit-Error-Rates of proximal decoder for different codes", fig = visualization.plot_BERs(title="Bit-Error-Rates of proximal decoder for different codes",
data, num_cols=4, line_labels=line_labels) data=data, num_cols=4)
plt.show() plt.show()

View File

@ -18,20 +18,22 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:
# TODO: Handle number of graphs not nicely fitting into rows and columns # TODO: Handle number of graphs not nicely fitting into rows and columns
def show_BER_curves(title: str, def plot_BERs(title: str,
data: typing.Dict[str, pd.DataFrame], data: typing.Sequence[typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]],
line_labels: typing.Dict[str, str], num_cols: int = 3) -> plt.figure:
num_cols: int = 3) -> plt.figure: """This function creates a matplotlib figure containing a number of plots.
"""This function creates a matplotlib figure containing a number of BER curves.
The plots created are logarithmic and the scaling is adjusted to be sensible for BER plots.
:param title: Title of the figure :param title: Title of the figure
:param data: Dictionary where each key corresponds to the name of a new graph and the value is a pandas Dataframe :param data: Sequence of tuples. Each tuple corresponds to a new plot and
containing the data to be plotted. Each dataframe is assumed to contain a column named "SNR" which is used is of the following form: [graph_title, pd.Dataframe, [line_label_1, line_label2, ...]].
as the x-axis Each dataframe is assumed to have an "SNR" column that is used as the x axis.
:param line_labels: Dictionary mapping column names to proper labels
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure :param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
:return: Matplotlib figure :return: Matplotlib figure
""" """
# Determine layout and create figure
num_graphs = len(data) num_graphs = len(data)
num_rows = _get_num_rows(num_graphs, num_cols) num_rows = _get_num_rows(num_graphs, num_cols)
@ -47,18 +49,19 @@ def show_BER_curves(title: str,
axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array
for axis, name_data_pair in zip(axes, sorted(data.items())): # Populate axes
graph_name, df = name_data_pair
for axis, (graph_title, df, labels) in zip(axes, data):
column_names = [column for column in df.columns.values.tolist() if not column == "SNR"] column_names = [column for column in df.columns.values.tolist() if not column == "SNR"]
for column in column_names: for column, label in zip(column_names, labels):
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=line_labels[column]) sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label)
axis.set_title(graph_name) axis.set_title(graph_title)
axis.set(yscale="log") axis.set(yscale="log")
axis.set_xlabel("SNR") axis.set_xlabel("SNR")
axis.set_ylabel("BER") axis.set_ylabel("BER")
axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0]) axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
axis.legend() axis.legend()
return fig return fig