68 lines
2.4 KiB
Python
68 lines
2.4 KiB
Python
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
import typing
|
|
from itertools import chain
|
|
import math
|
|
|
|
|
|
def _get_num_rows(num_graphs: int, num_cols: int) -> int:
|
|
"""Get the minimum number of rows needed to show a certain number of graphs,
|
|
given a certain number of columns.
|
|
|
|
:param num_graphs: Number of graphs
|
|
:param num_cols: Number of columns
|
|
:return: Number of rows
|
|
"""
|
|
return math.ceil(num_graphs / num_cols)
|
|
|
|
|
|
# Y-axis tick positions for BER plots: one tick per decade from 10^-5 to 10^0.
# Written as 1e-5 (== 10^-5), not 10e-5 (== 10^-4): Axis.set_ticks expands the
# view limits to include every tick, so an out-of-range tick stretches the axis.
_BER_YTICKS = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0]


def _plot_ber_graph(axis: plt.Axes,
                    graph_title: str,
                    df: pd.DataFrame,
                    labels: typing.Sequence[str]) -> None:
    """Draw one BER subplot: every non-"SNR" column of ``df`` against "SNR".

    :param axis: Axes to draw into
    :param graph_title: Title of this subplot
    :param df: Dataframe with an "SNR" column (x values) and one column per line
    :param labels: Legend labels, matched positionally to the non-"SNR" columns
    """
    value_columns = [column for column in df.columns if column != "SNR"]

    # zip() stops at the shorter sequence: columns without a label (or labels
    # without a column) are skipped rather than reported.
    for column, label in zip(value_columns, labels):
        sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label)

    axis.set_title(graph_title)
    axis.set(yscale="log")
    axis.set_xlabel("SNR")
    axis.set_ylabel("BER")
    axis.set_yticks(_BER_YTICKS)
    axis.legend()


def plot_BERs(title: str,
              data: typing.Sequence[typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]],
              num_cols: int = 3) -> plt.Figure:
    """Create a matplotlib figure containing a grid of BER plots.

    Every plot uses a logarithmic y axis with ticks at each decade from
    10^-5 to 10^0, which suits bit-error-rate curves. Grid cells left over
    when the graphs do not fill the last row are removed from the figure.

    :param title: Title of the figure
    :param data: Sequence of tuples. Each tuple corresponds to a new plot and
        is of the following form: (graph_title, pd.DataFrame, [line_label_1, line_label_2, ...]).
        Each dataframe is assumed to have an "SNR" column that is used as the x axis;
        every other column is drawn as one line, labelled positionally from the label list.
    :param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
    :return: Matplotlib figure
    :raises ValueError: If ``data`` is empty or ``num_cols`` is smaller than 1
    """
    num_graphs = len(data)
    # Validate up front: otherwise num_cols == 0 surfaces as a ZeroDivisionError
    # and empty data as an error deep inside matplotlib's GridSpec.
    if num_graphs == 0:
        raise ValueError("data must contain at least one graph")
    if num_cols < 1:
        raise ValueError(f"num_cols must be at least 1, got {num_cols}")

    # Determine layout and create figure
    num_rows = _get_num_rows(num_graphs, num_cols)
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(num_cols * 4, num_rows * 4),
                             squeeze=False)  # always a 2d array, even for 1 row/col
    fig.suptitle(title)
    fig.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.9,
                        top=0.9,
                        wspace=0.3,
                        hspace=0.4)

    # Flatten the 2d axes array row by row; the last row may be only partially used.
    flat_axes = list(chain.from_iterable(axes))

    # Remove grid cells that get no graph so they do not render as empty frames.
    for unused_axis in flat_axes[num_graphs:]:
        unused_axis.remove()

    # Populate axes
    for axis, (graph_title, df, labels) in zip(flat_axes, data):
        _plot_ber_graph(axis, graph_title, df, labels)

    return fig
|