import seaborn as sns import matplotlib.pyplot as plt import pandas as pd import typing from itertools import chain import math def _get_num_rows(num_graphs: int, num_cols: int) -> int: """Get the minimum number of rows needed to show a certain number of graphs, given a certain number of columns. :param num_graphs: Number of graphs :param num_cols: Number of columns :return: Number of rows """ return math.ceil(num_graphs / num_cols) # TODO: Calculate fig size in relation to the number of rows and columns # TODO: Handle number of graphs not nicely fitting into rows and columns def show_BER_curves(title: str, data: typing.Dict[str, pd.DataFrame], line_labels: typing.Dict[str, str], num_cols: int = 3) -> plt.figure: """This function creates a matplotlib figure containing a number of BER curves. :param title: Title of the figure :param data: Dictionary where each key corresponds to the name of a new graph and the value is a pandas Dataframe containing the data to be plotted. Each dataframe is assumed to contain a column named "SNR" which is used as the x-axis :param line_labels: Dictionary mapping column names to proper labels :param num_cols: Number of columns in which the graphs should be arranged in the resulting figure :return: Matplotlib figure """ num_graphs = len(data) num_rows = _get_num_rows(num_graphs, num_cols) fig, axes = plt.subplots(num_rows, num_cols, squeeze=False) fig.suptitle(title) axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array for axis, name_data_pair in zip(axes, sorted(data.items())): graph_name, df = name_data_pair column_names = [column for column in df.columns.values.tolist() if not column == "SNR"] for column in column_names: sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=line_labels[column]) axis.set_title(graph_name) axis.set(yscale="log") axis.set_xlabel("SNR") axis.set_ylabel("BER") axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0]) axis.legend() return fig