Modified visualization code to wread metadata
This commit is contained in:
parent
d0ed9ffbaa
commit
9beda2231d
@ -3,41 +3,39 @@ import matplotlib.pyplot as plt
|
|||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import os
|
import os
|
||||||
from utility import visualization
|
from utility import visualization, simulation
|
||||||
|
|
||||||
|
|
||||||
def plot_results():
|
def plot_results():
|
||||||
graph_names = {"96.3.965": "n=96, k=48 - 965",
|
|
||||||
"204.3.486": "n=204, k=102 - 486",
|
|
||||||
"204.55.187": "n=204, k=102 - 187",
|
|
||||||
"408.33.844": "n=408, k=204 - 844",
|
|
||||||
"816.1A4.845": "n=816, k=272 - 843",
|
|
||||||
"999.111.3.5543": "n=999, k=888 - 5543",
|
|
||||||
"999.111.3.5565": "n=999, k=888 - 5565",
|
|
||||||
"PEGReg252x504": "n=504, k=252 - PEGReg"}
|
|
||||||
|
|
||||||
line_labels = {"BER_ML": "ML",
|
sim_names = [
|
||||||
"BER_prox_0_15": "$\gamma = 0.15$",
|
"96.3965",
|
||||||
"BER_prox_0_05": "$\gamma = 0.05$",
|
"204.3.486",
|
||||||
"BER_prox_0_01": "$\gamma = 0.01$"}
|
"204.55.187",
|
||||||
|
"408.33.844",
|
||||||
|
"816.1A4.845",
|
||||||
|
"999.111.3.5543",
|
||||||
|
"999.111.3.5565",
|
||||||
|
"PEGReg252x504"
|
||||||
|
]
|
||||||
|
|
||||||
# Read data from files
|
deserializer = simulation.SimulationDeSerializer(save_dir="sim_saves", results_dir="sim_results")
|
||||||
results_dir = "sim_results"
|
|
||||||
|
|
||||||
data = {}
|
data = []
|
||||||
for file in os.listdir(results_dir):
|
for sim_name in sim_names:
|
||||||
if file.endswith(".csv"):
|
df, metadata = deserializer.read_results(sim_name)
|
||||||
filename = os.path.splitext(file)[0]
|
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
|
||||||
|
|
||||||
df = pd.read_csv(os.path.join(results_dir, file))
|
graph_title = sim_name
|
||||||
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
|
line_labels = metadata["labels"]
|
||||||
data[graph_names[filename]] = df
|
|
||||||
|
|
||||||
# Create and show graphs
|
graph_tuple = (graph_title, df, line_labels)
|
||||||
|
data.append(graph_tuple)
|
||||||
|
|
||||||
sns.set_theme()
|
sns.set_theme()
|
||||||
fig = visualization.show_BER_curves("Bit-Error-Rates of proximal decoder for different codes",
|
fig = visualization.plot_BERs(title="Bit-Error-Rates of proximal decoder for different codes",
|
||||||
data, num_cols=4, line_labels=line_labels)
|
data=data, num_cols=4)
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -18,20 +18,22 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:
|
|||||||
|
|
||||||
|
|
||||||
# TODO: Handle number of graphs not nicely fitting into rows and columns
|
# TODO: Handle number of graphs not nicely fitting into rows and columns
|
||||||
def show_BER_curves(title: str,
|
def plot_BERs(title: str,
|
||||||
data: typing.Dict[str, pd.DataFrame],
|
data: typing.Sequence[typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]],
|
||||||
line_labels: typing.Dict[str, str],
|
num_cols: int = 3) -> plt.figure:
|
||||||
num_cols: int = 3) -> plt.figure:
|
"""This function creates a matplotlib figure containing a number of plots.
|
||||||
"""This function creates a matplotlib figure containing a number of BER curves.
|
|
||||||
|
The plots created are logarithmic and the scaling is adjusted to be sensible for BER plots.
|
||||||
|
|
||||||
:param title: Title of the figure
|
:param title: Title of the figure
|
||||||
:param data: Dictionary where each key corresponds to the name of a new graph and the value is a pandas Dataframe
|
:param data: Sequence of tuples. Each tuple corresponds to a new plot and
|
||||||
containing the data to be plotted. Each dataframe is assumed to contain a column named "SNR" which is used
|
is of the following form: [graph_title, pd.Dataframe, [line_label_1, line_label2, ...]].
|
||||||
as the x-axis
|
Each dataframe is assumed to have an "SNR" column that is used as the x axis.
|
||||||
:param line_labels: Dictionary mapping column names to proper labels
|
|
||||||
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
|
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
|
||||||
:return: Matplotlib figure
|
:return: Matplotlib figure
|
||||||
"""
|
"""
|
||||||
|
# Determine layout and create figure
|
||||||
|
|
||||||
num_graphs = len(data)
|
num_graphs = len(data)
|
||||||
num_rows = _get_num_rows(num_graphs, num_cols)
|
num_rows = _get_num_rows(num_graphs, num_cols)
|
||||||
|
|
||||||
@ -47,18 +49,19 @@ def show_BER_curves(title: str,
|
|||||||
|
|
||||||
axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array
|
axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array
|
||||||
|
|
||||||
for axis, name_data_pair in zip(axes, sorted(data.items())):
|
# Populate axes
|
||||||
graph_name, df = name_data_pair
|
|
||||||
|
for axis, (graph_title, df, labels) in zip(axes, data):
|
||||||
column_names = [column for column in df.columns.values.tolist() if not column == "SNR"]
|
column_names = [column for column in df.columns.values.tolist() if not column == "SNR"]
|
||||||
|
|
||||||
for column in column_names:
|
for column, label in zip(column_names, labels):
|
||||||
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=line_labels[column])
|
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label)
|
||||||
|
|
||||||
axis.set_title(graph_name)
|
axis.set_title(graph_title)
|
||||||
axis.set(yscale="log")
|
axis.set(yscale="log")
|
||||||
axis.set_xlabel("SNR")
|
axis.set_xlabel("SNR")
|
||||||
axis.set_ylabel("BER")
|
axis.set_ylabel("BER")
|
||||||
axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
|
axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
|
||||||
axis.legend()
|
axis.legend()
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user