From 6fc01b20ff29b430f4793580a531b0f4baef6868 Mon Sep 17 00:00:00 2001 From: Andreas Tsouchlos Date: Tue, 8 Nov 2022 19:10:51 +0100 Subject: [PATCH] Restructured main.py --- sw/main.py | 111 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 69 insertions(+), 42 deletions(-) diff --git a/sw/main.py b/sw/main.py index e52e45a..d5e5161 100644 --- a/sw/main.py +++ b/sw/main.py @@ -2,75 +2,102 @@ import numpy as np import matplotlib.pyplot as plt import seaborn as sns import pandas as pd -from timeit import default_timer as timer +import typing +from pathlib import Path +import os +from itertools import chain from decoders import proximal, naive_soft_decision -from utility import noise, simulations, encoders, codes +from utility import simulations, encoders, codes -def main(): - # used_code = "Hamming_7_4" - # used_code = "Golay_24_12" - # used_code = "BCH_31_16" - # used_code = "BCH_31_21" - used_code = "BCH_63_16" - - G = codes.Gs[used_code] - H = codes.get_systematic_H(G) - R = codes.get_systematic_R(G) - - # Define encoder and decoders - - encoder = encoders.Encoder(G) - - decoders = { - # "naive_soft_decision": naive_soft_decision.SoftDecisionDecoder(G, H, R), - "proximal_0_01": proximal.ProximalDecoder(H, R, K=100, gamma=0.01), - "proximal_0_05": proximal.ProximalDecoder(H, R, K=100, gamma=0.05), - "proximal_0_15": proximal.ProximalDecoder(H, R, K=100, gamma=0.15), - } - - # Test decoders - +def test_decoders(G, encoder, decoders: typing.List) -> pd.DataFrame: k, n = G.shape d = np.zeros(k) # All-zeros assumption - SNRs = np.linspace(1, 8, 9) + SNRs = np.linspace(1, 8, 8) data = pd.DataFrame({"SNR": SNRs}) - start_time = timer() - for decoder_name in decoders: decoder = decoders[decoder_name] _, BERs_sd = simulations.test_decoder(encoder=encoder, decoder=decoder, d=d, SNRs=SNRs, - target_bit_errors=500, - N_max=10000) + target_bit_errors=100, + N_max=30000) data[f"BER_{decoder_name}"] = BERs_sd - stop_time = timer() - print(f"Elapsed time: {stop_time - start_time:2f}") + return data - # Plot results + +# TODO: Fix spacing between axes and margins +def plot_results(): + results_dir = "sim_results" + + code_paths = {} + for file in os.listdir(results_dir): + if file.endswith(".csv"): + code_paths[file.replace(".csv", "")] = os.path.join(results_dir, file) sns.set_theme() - fig, axes = plt.subplots(1, 1) - fig.suptitle("Bit-Error-Rates of various decoders") + fig, axes = plt.subplots(2, len(code_paths) // 2, figsize=(12, 6)) + fig.suptitle("Bit-Error-Rates of various decoders for different codes") - for decoder_name in decoders: - ax = sns.lineplot(data=data, x="SNR", y=f"BER_{decoder_name}", label=f"{decoder_name}") + axes = list(chain.from_iterable(axes)) - ax.set_title(used_code) - ax.set(yscale="log") - ax.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0]) - ax.legend() + for i, code in enumerate(code_paths): + data = pd.read_csv(code_paths[code]) + + column_names = [column for column in data.columns.values.tolist() + if column.startswith("BER")] + + ax = axes[i] + + for column in column_names: + sns.lineplot(ax=ax, data=data, x="SNR", y=column, label=column.lstrip("BER_")) + + ax.set_title(code) + ax.set(yscale="log") + ax.set_xlabel("SNR") + ax.set_ylabel("BER") + ax.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0]) + ax.legend() plt.show() +def main(): + Path("sim_results").mkdir(parents=True, exist_ok=True) + + used_codes = [ + "Hamming_7_4", + "Golay_24_12", + # "BCH_31_16", + # "BCH_31_21", + # "BCH_63_16", + ] + + for used_code in used_codes: + G = codes.Gs[used_code] + H = codes.get_systematic_H(G) + R = codes.get_systematic_R(G) + + encoder = encoders.Encoder(G) + decoders = { + "naive_soft_decision": naive_soft_decision.SoftDecisionDecoder(G, H, R), + "proximal_0_01": proximal.ProximalDecoder(H, R, K=100, gamma=0.01), + "proximal_0_05": proximal.ProximalDecoder(H, R, K=100, gamma=0.05), + "proximal_0_15": proximal.ProximalDecoder(H, R, K=100, gamma=0.15), + } + + # data = test_decoders(G, encoder, decoders) + # data.to_csv(f"sim_results/{used_code}.csv") + + plot_results() + + if __name__ == "__main__": main()