ba-thesis/sw/main.py

104 lines
2.9 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import typing
from pathlib import Path
import os
from itertools import chain
from decoders import proximal, naive_soft_decision
from utility import simulations, encoders, codes
def test_decoders(G, encoder, decoders: typing.Mapping[str, typing.Any],
                  SNRs=None, target_bit_errors: int = 100,
                  N_max: int = 30000) -> pd.DataFrame:
    """Simulate every given decoder over a set of SNRs and tabulate the results.

    Args:
        G: Matrix with a two-element ``shape``; its first dimension is used
            as the length of the transmitted data word.
        encoder: Encoder object, passed through to
            ``simulations.test_decoder``.
        decoders: Mapping from a decoder name to the decoder object. The
            name is used to build the result column ``BER_<name>``.
        SNRs: SNR values to simulate. Defaults to ``np.linspace(1, 8, 8)``,
            i.e. 1, 2, ..., 8.
        target_bit_errors: Forwarded unchanged to
            ``simulations.test_decoder`` (presumably a stopping criterion;
            confirm against that function).
        N_max: Forwarded unchanged to ``simulations.test_decoder``
            (presumably an upper bound on simulated frames; confirm).

    Returns:
        A DataFrame with an ``SNR`` column and one ``BER_<name>`` column per
        entry in ``decoders``.
    """
    # Unpacking (rather than indexing) keeps the implicit check that G is 2-D.
    k, _ = G.shape
    d = np.zeros(k)  # All-zeros assumption

    if SNRs is None:
        SNRs = np.linspace(1, 8, 8)

    data = pd.DataFrame({"SNR": SNRs})

    for decoder_name, decoder in decoders.items():
        # Only the second element of the returned pair is stored.
        _, BERs = simulations.test_decoder(encoder=encoder,
                                           decoder=decoder,
                                           d=d,
                                           SNRs=SNRs,
                                           target_bit_errors=target_bit_errors,
                                           N_max=N_max)
        data[f"BER_{decoder_name}"] = BERs

    return data
# TODO: Fix spacing between axes and margins
def _decoder_label(column: str) -> str:
    """Return the legend label for a BER column: its name minus "BER_"."""
    prefix = "BER_"
    # Slice off the exact prefix. str.lstrip("BER_") would instead remove any
    # leading run of the characters B, E, R and _, mangling names such as
    # "BER_BP" (which would become "P").
    if column.startswith(prefix):
        return column[len(prefix):]
    return column


def plot_results():
    """Plot the bit-error-rate curves stored in the ``sim_results`` CSV files.

    Every ``*.csv`` file in ``sim_results`` gets its own subplot, titled with
    the file name minus its extension. Each column whose name starts with
    "BER" is drawn against the "SNR" column on a logarithmic y-axis, and the
    figure is displayed with ``plt.show()``. If no CSV file is found, a
    message is printed and nothing is plotted.
    """
    results_dir = Path("sim_results")

    # Sorted so that the subplot order is reproducible between runs.
    code_paths = {path.stem: path
                  for path in sorted(results_dir.glob("*.csv"))}

    if not code_paths:
        print(f"No simulation results (*.csv) found in '{results_dir}'; "
              "nothing to plot.")
        return

    sns.set_theme()

    # Two rows of subplots; the column count is rounded up so an odd number
    # of codes still fits. squeeze=False keeps `axes` two-dimensional even
    # when there is only a single column, so it can always be flattened.
    n_cols = (len(code_paths) + 1) // 2
    fig, axes = plt.subplots(2, n_cols, figsize=(12, 6), squeeze=False)
    fig.suptitle("Bit-Error-Rates of various decoders for different codes")
    flat_axes = axes.flatten()

    for ax, (code, path) in zip(flat_axes, code_paths.items()):
        data = pd.read_csv(path)
        ber_columns = [column for column in data.columns.values.tolist()
                       if column.startswith("BER")]

        for column in ber_columns:
            sns.lineplot(ax=ax, data=data, x="SNR", y=column,
                         label=_decoder_label(column))

        ax.set_title(code)
        ax.set(yscale="log")
        ax.set_xlabel("SNR")
        # Overrides the per-column y label that the line plots leave behind.
        ax.set_ylabel("BER")
        # NOTE(review): 10e-5 equals 1e-4, so these ticks run from 1e-4 up to
        # 1e1. Confirm whether 1e-5 ... 1e0 was intended.
        ax.set_yticks([10e-5, 10e-4, 10e-3, 10e-2, 10e-1, 10e0])
        ax.legend()

    # With an odd number of codes the last grid cell has no data; hide it.
    for unused_ax in flat_axes[len(code_paths):]:
        unused_ax.set_visible(False)

    plt.show()
def main(run_simulations: bool = False):
    """Optionally run the decoder simulations, then plot the stored results.

    Args:
        run_simulations: When true, every code listed in ``used_codes`` is
            simulated with ``test_decoders`` and the resulting table is
            written to ``sim_results/<code>.csv``. When false (the default),
            no simulation is run and only the CSV files already present in
            ``sim_results`` are plotted.
    """
    Path("sim_results").mkdir(parents=True, exist_ok=True)

    used_codes = [
        "Hamming_7_4",
        "Golay_24_12",
        # Currently disabled:
        # "BCH_31_16",
        # "BCH_31_21",
        # "BCH_63_16",
    ]

    if run_simulations:
        for used_code in used_codes:
            G = codes.Gs[used_code]
            H = codes.get_systematic_H(G)
            R = codes.get_systematic_R(G)

            encoder = encoders.Encoder(G)

            # Keys become the "BER_<key>" column names in the result table.
            decoders = {
                "naive_soft_decision":
                    naive_soft_decision.SoftDecisionDecoder(G, H, R),
                "proximal_0_01":
                    proximal.ProximalDecoder(H, R, K=100, gamma=0.01),
                "proximal_0_05":
                    proximal.ProximalDecoder(H, R, K=100, gamma=0.05),
                "proximal_0_15":
                    proximal.ProximalDecoder(H, R, K=100, gamma=0.15),
            }

            data = test_decoders(G, encoder, decoders)
            data.to_csv(f"sim_results/{used_code}.csv")

    plot_results()
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()