Finished first rough implementation of show_BER_curves()
This commit is contained in:
parent
2fae3ba3be
commit
fad296f759
93
sw/main.py
93
sw/main.py
@ -9,75 +9,58 @@ from itertools import chain
|
|||||||
from timeit import default_timer
|
from timeit import default_timer
|
||||||
|
|
||||||
from decoders import proximal, maximum_likelihood
|
from decoders import proximal, maximum_likelihood
|
||||||
from utility import simulations, codes
|
from utility import simulations, codes, visualization
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fix spacing between axes and margins
|
# TODO: Fix spacing between axes and margins
|
||||||
def plot_results():
|
def plot_results():
|
||||||
results_dir = "sim_results"
|
results_dir = "sim_results"
|
||||||
|
|
||||||
code_paths = {}
|
# Read data from files
|
||||||
|
|
||||||
|
data = []
|
||||||
for file in os.listdir(results_dir):
|
for file in os.listdir(results_dir):
|
||||||
if file.endswith(".csv"):
|
if file.endswith(".csv"):
|
||||||
code_paths[file.replace(".csv", "")] = os.path.join(results_dir, file)
|
df = pd.read_csv(os.path.join(results_dir, file))
|
||||||
|
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
|
||||||
|
data.append(df)
|
||||||
|
|
||||||
|
# Create and show graphs
|
||||||
|
|
||||||
sns.set_theme()
|
sns.set_theme()
|
||||||
|
fig = visualization.show_BER_curves(data)
|
||||||
fig, axes = plt.subplots(2, len(code_paths) // 2, figsize=(12, 6))
|
|
||||||
fig.suptitle("Bit-Error-Rates of various decoders for different codes")
|
|
||||||
|
|
||||||
axes = list(chain.from_iterable(axes))
|
|
||||||
|
|
||||||
for i, code in enumerate(code_paths):
|
|
||||||
data = pd.read_csv(code_paths[code])
|
|
||||||
|
|
||||||
column_names = [column for column in data.columns.values.tolist()
|
|
||||||
if column.startswith("BER")]
|
|
||||||
|
|
||||||
ax = axes[i]
|
|
||||||
|
|
||||||
for column in column_names:
|
|
||||||
sns.lineplot(ax=ax, data=data, x="SNR", y=column, label=column.lstrip("BER_"))
|
|
||||||
|
|
||||||
ax.set_title(code)
|
|
||||||
ax.set(yscale="log")
|
|
||||||
ax.set_xlabel("SNR")
|
|
||||||
ax.set_ylabel("BER")
|
|
||||||
ax.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
|
|
||||||
ax.legend()
|
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
Path("sim_results").mkdir(parents=True, exist_ok=True)
|
# Path("sim_results").mkdir(parents=True, exist_ok=True)
|
||||||
|
#
|
||||||
# used_code = "Hamming_7_4"
|
# # used_code = "Hamming_7_4"
|
||||||
# used_code = "Golay_24_12"
|
# # used_code = "Golay_24_12"
|
||||||
used_code = "BCH_31_16"
|
# used_code = "BCH_31_16"
|
||||||
# used_code = "BCH_31_21"
|
# # used_code = "BCH_31_21"
|
||||||
# used_code = "BCH_63_16"
|
# # used_code = "BCH_63_16"
|
||||||
|
#
|
||||||
G = codes.Gs[used_code]
|
# G = codes.Gs[used_code]
|
||||||
H = codes.get_systematic_H(G)
|
# H = codes.get_systematic_H(G)
|
||||||
|
#
|
||||||
decoders = [
|
# decoders = [
|
||||||
maximum_likelihood.MLDecoder(G, H),
|
# maximum_likelihood.MLDecoder(G, H),
|
||||||
proximal.ProximalDecoder(H, gamma=0.01),
|
# proximal.ProximalDecoder(H, gamma=0.01),
|
||||||
proximal.ProximalDecoder(H, gamma=0.05),
|
# proximal.ProximalDecoder(H, gamma=0.05),
|
||||||
proximal.ProximalDecoder(H, gamma=0.15)
|
# proximal.ProximalDecoder(H, gamma=0.15)
|
||||||
]
|
# ]
|
||||||
|
#
|
||||||
k, n = G.shape
|
# k, n = G.shape
|
||||||
SNRs, BERs = simulations.test_decoders(n, k, decoders, N_max=30000, target_frame_errors=100)
|
# SNRs, BERs = simulations.test_decoders(n, k, decoders, N_max=30000, target_frame_errors=100)
|
||||||
|
#
|
||||||
df = pd.DataFrame({"SNR": SNRs})
|
# df = pd.DataFrame({"SNR": SNRs})
|
||||||
df["BER_ML"] = BERs[0]
|
# df["BER_ML"] = BERs[0]
|
||||||
df["BER_prox_0_01"] = BERs[0]
|
# df["BER_prox_0_01"] = BERs[0]
|
||||||
df["BER_prox_0_05"] = BERs[1]
|
# df["BER_prox_0_05"] = BERs[1]
|
||||||
df["BER_prox_0_15"] = BERs[2]
|
# df["BER_prox_0_15"] = BERs[2]
|
||||||
|
#
|
||||||
df.to_csv(f"sim_results/{used_code}.csv")
|
# df.to_csv(f"sim_results/{used_code}.csv")
|
||||||
|
|
||||||
plot_results()
|
plot_results()
|
||||||
|
|
||||||
|
|||||||
52
sw/utility/visualization.py
Normal file
52
sw/utility/visualization.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
import typing
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
|
||||||
|
def _get_num_rows(num_graphs: int, num_cols: int) -> int:
|
||||||
|
"""Get the minimum number of rows needed to show a certain number of graphs,
|
||||||
|
given a certain number of columns.
|
||||||
|
|
||||||
|
:param num_graphs: Number of graphs
|
||||||
|
:param num_cols: Number of columns
|
||||||
|
:return: Number of rows
|
||||||
|
"""
|
||||||
|
return num_graphs // num_cols + 1
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Calculate fig size in relation to the number of rows and columns
|
||||||
|
# TODO: Set proper line labels
|
||||||
|
# TODO: Set proper axis titles
|
||||||
|
# TODO: Should unnamed columns be dropped by this function or by the caller?
|
||||||
|
def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.figure:
|
||||||
|
"""This function creates a matplotlib figure containing a number of BER curves.
|
||||||
|
|
||||||
|
:param data: List of pandas DataFrames containing the data to be plotted. Each element in the list is plotted in
|
||||||
|
a new graph. Each dataframe is assumed to contain a column named "SNR" which is used as the x-axis
|
||||||
|
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
|
||||||
|
:return: Matplotlib figure
|
||||||
|
"""
|
||||||
|
num_graphs = len(data)
|
||||||
|
num_rows = _get_num_rows(num_cols, num_cols)
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(num_rows, num_cols)
|
||||||
|
fig.suptitle("Bit-Error-Rates of various decoders for different codes")
|
||||||
|
|
||||||
|
axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array
|
||||||
|
|
||||||
|
for axis, df in zip(axes, data):
|
||||||
|
column_names = [column for column in df.columns.values.tolist() if not column == "SNR"]
|
||||||
|
|
||||||
|
for column in column_names:
|
||||||
|
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=column)
|
||||||
|
|
||||||
|
#axis.set_title(code)
|
||||||
|
axis.set(yscale="log")
|
||||||
|
axis.set_xlabel("SNR")
|
||||||
|
axis.set_ylabel("BER")
|
||||||
|
axis.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
|
||||||
|
axis.legend()
|
||||||
|
|
||||||
|
return fig
|
||||||
Loading…
Reference in New Issue
Block a user