Restructured main.py

This commit is contained in:
Andreas Tsouchlos 2022-11-08 19:10:51 +01:00
parent d8258a36f6
commit 6fc01b20ff

View File

@ -2,75 +2,102 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
import pandas as pd import pandas as pd
from timeit import default_timer as timer import typing
from pathlib import Path
import os
from itertools import chain
from decoders import proximal, naive_soft_decision from decoders import proximal, naive_soft_decision
from utility import noise, simulations, encoders, codes from utility import simulations, encoders, codes
def main(): def test_decoders(G, encoder, decoders: typing.List) -> pd.DataFrame:
# used_code = "Hamming_7_4"
# used_code = "Golay_24_12"
# used_code = "BCH_31_16"
# used_code = "BCH_31_21"
used_code = "BCH_63_16"
G = codes.Gs[used_code]
H = codes.get_systematic_H(G)
R = codes.get_systematic_R(G)
# Define encoder and decoders
encoder = encoders.Encoder(G)
decoders = {
# "naive_soft_decision": naive_soft_decision.SoftDecisionDecoder(G, H, R),
"proximal_0_01": proximal.ProximalDecoder(H, R, K=100, gamma=0.01),
"proximal_0_05": proximal.ProximalDecoder(H, R, K=100, gamma=0.05),
"proximal_0_15": proximal.ProximalDecoder(H, R, K=100, gamma=0.15),
}
# Test decoders
k, n = G.shape k, n = G.shape
d = np.zeros(k) # All-zeros assumption d = np.zeros(k) # All-zeros assumption
SNRs = np.linspace(1, 8, 9) SNRs = np.linspace(1, 8, 8)
data = pd.DataFrame({"SNR": SNRs}) data = pd.DataFrame({"SNR": SNRs})
start_time = timer()
for decoder_name in decoders: for decoder_name in decoders:
decoder = decoders[decoder_name] decoder = decoders[decoder_name]
_, BERs_sd = simulations.test_decoder(encoder=encoder, _, BERs_sd = simulations.test_decoder(encoder=encoder,
decoder=decoder, decoder=decoder,
d=d, d=d,
SNRs=SNRs, SNRs=SNRs,
target_bit_errors=500, target_bit_errors=100,
N_max=10000) N_max=30000)
data[f"BER_{decoder_name}"] = BERs_sd data[f"BER_{decoder_name}"] = BERs_sd
stop_time = timer() return data
print(f"Elapsed time: {stop_time - start_time:2f}")
# Plot results
# TODO: Fix spacing between axes and margins
def plot_results():
results_dir = "sim_results"
code_paths = {}
for file in os.listdir(results_dir):
if file.endswith(".csv"):
code_paths[file.replace(".csv", "")] = os.path.join(results_dir, file)
sns.set_theme() sns.set_theme()
fig, axes = plt.subplots(1, 1) fig, axes = plt.subplots(2, len(code_paths) // 2, figsize=(12, 6))
fig.suptitle("Bit-Error-Rates of various decoders") fig.suptitle("Bit-Error-Rates of various decoders for different codes")
for decoder_name in decoders: axes = list(chain.from_iterable(axes))
ax = sns.lineplot(data=data, x="SNR", y=f"BER_{decoder_name}", label=f"{decoder_name}")
ax.set_title(used_code) for i, code in enumerate(code_paths):
ax.set(yscale="log") data = pd.read_csv(code_paths[code])
ax.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
ax.legend() column_names = [column for column in data.columns.values.tolist()
if column.startswith("BER")]
ax = axes[i]
for column in column_names:
sns.lineplot(ax=ax, data=data, x="SNR", y=column, label=column.lstrip("BER_"))
ax.set_title(code)
ax.set(yscale="log")
ax.set_xlabel("SNR")
ax.set_ylabel("BER")
ax.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
ax.legend()
plt.show() plt.show()
def main():
Path("sim_results").mkdir(parents=True, exist_ok=True)
used_codes = [
"Hamming_7_4",
"Golay_24_12",
# "BCH_31_16",
# "BCH_31_21",
# "BCH_63_16",
]
for used_code in used_codes:
G = codes.Gs[used_code]
H = codes.get_systematic_H(G)
R = codes.get_systematic_R(G)
encoder = encoders.Encoder(G)
decoders = {
"naive_soft_decision": naive_soft_decision.SoftDecisionDecoder(G, H, R),
"proximal_0_01": proximal.ProximalDecoder(H, R, K=100, gamma=0.01),
"proximal_0_05": proximal.ProximalDecoder(H, R, K=100, gamma=0.05),
"proximal_0_15": proximal.ProximalDecoder(H, R, K=100, gamma=0.15),
}
# data = test_decoders(G, encoder, decoders)
# data.to_csv(f"sim_results/{used_code}.csv")
plot_results()
if __name__ == "__main__": if __name__ == "__main__":
main() main()