Modified visualization code to wread metadata

This commit is contained in:
2022-11-21 13:28:00 +01:00
parent d0ed9ffbaa
commit 9beda2231d
2 changed files with 45 additions and 44 deletions

View File

@@ -18,20 +18,22 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:
# TODO: Handle number of graphs not nicely fitting into rows and columns
def show_BER_curves(title: str,
data: typing.Dict[str, pd.DataFrame],
line_labels: typing.Dict[str, str],
num_cols: int = 3) -> plt.figure:
"""This function creates a matplotlib figure containing a number of BER curves.
def plot_BERs(title: str,
data: typing.Sequence[typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]],
num_cols: int = 3) -> plt.figure:
"""This function creates a matplotlib figure containing a number of plots.
The plots created are logarithmic and the scaling is adjusted to be sensible for BER plots.
:param title: Title of the figure
:param data: Dictionary where each key corresponds to the name of a new graph and the value is a pandas Dataframe
containing the data to be plotted. Each dataframe is assumed to contain a column named "SNR" which is used
as the x-axis
:param line_labels: Dictionary mapping column names to proper labels
:param data: Sequence of tuples. Each tuple corresponds to a new plot and
is of the following form: [graph_title, pd.Dataframe, [line_label_1, line_label2, ...]].
Each dataframe is assumed to have an "SNR" column that is used as the x axis.
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
:return: Matplotlib figure
"""
# Determine layout and create figure
num_graphs = len(data)
num_rows = _get_num_rows(num_graphs, num_cols)
@@ -47,18 +49,19 @@ def show_BER_curves(title: str,
axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array
for axis, name_data_pair in zip(axes, sorted(data.items())):
graph_name, df = name_data_pair
# Populate axes
for axis, (graph_title, df, labels) in zip(axes, data):
column_names = [column for column in df.columns.values.tolist() if not column == "SNR"]
for column in column_names:
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=line_labels[column])
for column, label in zip(column_names, labels):
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label)
axis.set_title(graph_name)
axis.set(yscale="log")
axis.set_xlabel("SNR")
axis.set_ylabel("BER")
axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
axis.legend()
axis.set_title(graph_title)
axis.set(yscale="log")
axis.set_xlabel("SNR")
axis.set_ylabel("BER")
axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
axis.legend()
return fig