ba-thesis/sw/utility/visualization.py

55 lines
2.0 KiB
Python

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import typing
from itertools import chain
import math
def _get_num_rows(num_graphs: int, num_cols: int) -> int:
"""Get the minimum number of rows needed to show a certain number of graphs,
given a certain number of columns.
:param num_graphs: Number of graphs
:param num_cols: Number of columns
:return: Number of rows
"""
return math.ceil(num_graphs / num_cols)
# TODO: Calculate fig size in relation to the number of rows and columns
# TODO: Set proper line labels
# TODO: Set proper axis titles
# TODO: Should unnamed columns be dropped by this function or by the caller?
def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.Figure:
    """Create a matplotlib figure containing one BER-over-SNR graph per DataFrame.

    :param data: List of pandas DataFrames containing the data to be plotted. Each dataframe in the list is plotted
        in its own graph. Each dataframe is assumed to contain a column named "SNR" which is used as the x-axis;
        every other column is drawn as a separate line labelled with its column name.
    :param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
    :return: Matplotlib figure
    :raises ValueError: If ``data`` is empty or ``num_cols`` is smaller than 1
    """
    num_graphs = len(data)
    if num_graphs == 0:
        raise ValueError("At least one DataFrame is required to plot BER curves")
    if num_cols < 1:
        raise ValueError(f"num_cols must be a positive integer, got {num_cols}")

    num_rows = _get_num_rows(num_graphs, num_cols)
    fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.suptitle("Bit-Error-Rates of various decoders for different codes")

    # squeeze=False guarantees a 2D grid of Axes; flatten it in row-major order.
    flat_axes = list(chain.from_iterable(axes))

    # The last row may be only partially filled; hide the leftover axes
    # instead of rendering empty coordinate systems.
    for unused_axis in flat_axes[num_graphs:]:
        unused_axis.set_visible(False)

    for axis, df in zip(flat_axes, data):
        curve_columns = [column for column in df.columns if column != "SNR"]
        for column in curve_columns:
            sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=column)

        axis.set(yscale="log")
        axis.set_xlabel("SNR")
        axis.set_ylabel("BER")
        # Decades from 1e-5 up to 1 (a BER never exceeds 1). Note that a
        # literal like 10e-5 means 1e-4 in Python, so the exponents are spelled out.
        axis.set_yticks([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0])
        # Only draw a legend when there is at least one labelled line.
        if curve_columns:
            axis.legend()
    return fig