Implemented proper titles and line labels in visualization.show_BER_curves()
This commit is contained in:
parent
524b57f41c
commit
d9009970ad
46
sw/plot_results.py
Normal file
46
sw/plot_results.py
Normal file
@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import pandas as pd
|
||||
import os
|
||||
from utility import visualization
|
||||
|
||||
|
||||
# TODO: Fix spacing between axes and margins
|
||||
def plot_results():
|
||||
graph_names = {"96.3.965": "n=96, k=48 - 965",
|
||||
"204.3.486": "n=204, k=102 - 486",
|
||||
"204.55.187": "n=204, k=102 - 187",
|
||||
"408.33.844": "n=408, k=204 - 844",
|
||||
"816.1A4.845": "n=816, k=272 - 843",
|
||||
"999.111.3.5543": "n=999, k=888 - 5543",
|
||||
"999.111.3.5565": "n=999, k=888 - 5565",
|
||||
"PEGReg252x504": "n=504, k=252 - PEGReg"}
|
||||
|
||||
line_labels = {"BER_ML": "ML",
|
||||
"BER_prox_0_15": "$\gamma = 0.15$",
|
||||
"BER_prox_0_05": "$\gamma = 0.05$",
|
||||
"BER_prox_0_01": "$\gamma = 0.01$"}
|
||||
|
||||
# Read data from files
|
||||
results_dir = "sim_results"
|
||||
|
||||
data = {}
|
||||
for file in os.listdir(results_dir):
|
||||
if file.endswith(".csv"):
|
||||
filename = os.path.splitext(file)[0]
|
||||
|
||||
df = pd.read_csv(os.path.join(results_dir, file))
|
||||
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
|
||||
data[graph_names[filename]] = df
|
||||
|
||||
# Create and show graphs
|
||||
|
||||
sns.set_theme()
|
||||
fig = visualization.show_BER_curves("Bit-Error-Rates of proximal decoder for different codes",
|
||||
data, num_cols=4, line_labels=line_labels)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
plot_results()
|
||||
@ -18,15 +18,18 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:
|
||||
|
||||
|
||||
# TODO: Calculate fig size in relation to the number of rows and columns
|
||||
# TODO: Set proper line labels
|
||||
# TODO: Set proper axis titles
|
||||
# TODO: Should unnamed columns be dropped by this function or by the caller?
|
||||
# TODO: Handle number of graphs not nicely fitting into rows and columns
|
||||
def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.figure:
|
||||
def show_BER_curves(title: str,
|
||||
data: typing.Dict[str, pd.DataFrame],
|
||||
line_labels: typing.Dict[str, str],
|
||||
num_cols: int = 3) -> plt.figure:
|
||||
"""This function creates a matplotlib figure containing a number of BER curves.
|
||||
|
||||
:param data: List of pandas DataFrames containing the data to be plotted. Each dataframe in the list is plotted
|
||||
in a new graph. Each dataframe is assumed to contain a column named "SNR" which is used as the x-axis
|
||||
:param title: Title of the figure
|
||||
:param data: Dictionary where each key corresponds to the name of a new graph and the value is a pandas Dataframe
|
||||
containing the data to be plotted. Each dataframe is assumed to contain a column named "SNR" which is used
|
||||
as the x-axis
|
||||
:param line_labels: Dictionary mapping column names to proper labels
|
||||
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
|
||||
:return: Matplotlib figure
|
||||
"""
|
||||
@ -34,17 +37,18 @@ def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.f
|
||||
num_rows = _get_num_rows(num_graphs, num_cols)
|
||||
|
||||
fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
|
||||
fig.suptitle("Bit-Error-Rates of various decoders for different codes")
|
||||
fig.suptitle(title)
|
||||
|
||||
axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array
|
||||
|
||||
for axis, df in zip(axes, data):
|
||||
for axis, name_data_pair in zip(axes, sorted(data.items())):
|
||||
graph_name, df = name_data_pair
|
||||
column_names = [column for column in df.columns.values.tolist() if not column == "SNR"]
|
||||
|
||||
for column in column_names:
|
||||
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=column)
|
||||
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=line_labels[column])
|
||||
|
||||
#axis.set_title(code)
|
||||
axis.set_title(graph_name)
|
||||
axis.set(yscale="log")
|
||||
axis.set_xlabel("SNR")
|
||||
axis.set_ylabel("BER")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user