import math
import typing
from itertools import chain

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Y-axis tick positions for BER plots: one tick per decade from 10^-5 to 10^0.
# Written as 1eX (not 10eX) on purpose: the literal 10e-5 means 1e-4, which
# would shift every tick one decade up and put a tick at BER = 10.
_BER_YTICKS = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0]


def _get_num_rows(num_graphs: int, num_cols: int) -> int:
    """Get the minimum number of rows needed to show a certain number of
    graphs, given a certain number of columns.

    :param num_graphs: Number of graphs
    :param num_cols: Number of columns (must be non-zero)
    :return: Number of rows, i.e. ``ceil(num_graphs / num_cols)``
    """
    return math.ceil(num_graphs / num_cols)


def plot_BERs(title: str,
              data: typing.Sequence[
                  typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]],
              num_cols: int = 3) -> plt.Figure:
    """Create a matplotlib figure containing a grid of BER plots.

    The y axis of every plot is logarithmic, labelled "BER" and ticked at the
    decades 10^-5 ... 10^0; the x axis is labelled "SNR". Grid cells left over
    when the number of graphs is not a multiple of ``num_cols`` are hidden.

    :param title: Title of the figure
    :param data: Sequence of tuples. Each tuple corresponds to a new plot and
        is of the following form:
        (graph_title, pd.DataFrame, [line_label_1, line_label_2, ...]).
        Each dataframe is assumed to have an "SNR" column that is used as the
        x axis; every other column is drawn as one line, paired positionally
        with the labels (surplus columns or labels are ignored).
    :param num_cols: Number of columns in which the graphs should be arranged
        in the resulting figure
    :return: Matplotlib figure
    :raises ValueError: If ``num_cols`` is smaller than 1
    """
    if num_cols < 1:
        raise ValueError(f"num_cols must be at least 1, got {num_cols}")

    # Determine layout and create figure. At least one row is created so an
    # empty ``data`` yields a blank figure instead of an error from
    # plt.subplots (which rejects zero rows).
    num_graphs = len(data)
    num_rows = max(_get_num_rows(num_graphs, num_cols), 1)
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(num_cols * 4, num_rows * 4),
                             squeeze=False)
    fig.suptitle(title)
    fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.3, hspace=0.4)

    # Flatten the 2d axes array row by row so graphs fill rows left to right.
    flat_axes = list(chain.from_iterable(axes))

    # Populate axes
    for axis, (graph_title, df, labels) in zip(flat_axes, data):
        column_names = [column for column in df.columns if column != "SNR"]
        for column, label in zip(column_names, labels):
            sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label)
        axis.set_title(graph_title)
        axis.set(yscale="log")
        axis.set_xlabel("SNR")
        axis.set_ylabel("BER")
        axis.set_yticks(_BER_YTICKS)
        axis.legend()

    # Hide grid cells that received no graph so they don't render as empty
    # framed axes.
    for unused_axis in flat_axes[num_graphs:]:
        unused_axis.set_axis_off()

    return fig