Cleaned up simulate_2d_BER.py

This commit is contained in:
Andreas Tsouchlos 2022-12-05 15:11:57 +01:00
parent e6606959f1
commit 4c5e80c56e

View File

@ -1,14 +1,8 @@
import typing
import numpy as np import numpy as np
import pandas as pd
import seaborn as sns import seaborn as sns
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import signal import signal
from timeit import default_timer from timeit import default_timer
from tqdm import tqdm
from dataclasses import dataclass
from types import MappingProxyType
from utility import codes, noise, misc from utility import codes, noise, misc
from utility.simulation.simulators import GenericMultithreadedSimulator from utility.simulation.simulators import GenericMultithreadedSimulator
@ -16,10 +10,6 @@ from utility.simulation.simulators import GenericMultithreadedSimulator
from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
return np.sum(d != d_hat)
def task_func(params): def task_func(params):
"""Function called by the GenericMultithreadedSimulator instance. """Function called by the GenericMultithreadedSimulator instance.
@ -46,7 +36,7 @@ def task_func(params):
x = noise.add_awgn(x_bpsk, SNR, n, k) x = noise.add_awgn(x_bpsk, SNR, n, k)
x_hat, k_max = decoder.decode(x) x_hat, k_max = decoder.decode(x)
bit_errors = count_bit_errors(x_hat, c) bit_errors = misc.count_bit_errors(x_hat, c)
if bit_errors > 0: if bit_errors > 0:
total_bit_errors += bit_errors total_bit_errors += bit_errors
total_frame_errors += 1 total_frame_errors += 1
@ -67,16 +57,30 @@ def task_func(params):
"num_iterations": num_iterations} "num_iterations": num_iterations}
def simulate(H_file, SNRs, max_iterations, omega, K, gammas): def get_params():
sim = GenericMultithreadedSimulator() # Define global simulation parameters
# Define fixed simulation params # H_file = "BCH_7_4.alist"
# H_file = "BCH_31_11.alist"
# H_file = "BCH_31_26.alist"
# H_file = "96.3.965.alist"
H_file = "204.33.486.alist"
# H_file = "204.33.484.alist"
# H_file = "204.55.187.alist"
# H_file = "408.33.844.alist"
H = codes.read_alist_file(f"res/{H_file}") H = codes.read_alist_file(f"res/{H_file}")
n_min_k, n = H.shape n_min_k, n = H.shape
k = n - n_min_k k = n - n_min_k
# Define params different for each task omega = 0.05
K = 100
gammas = np.arange(0.0, 0.17, 0.01)
SNRs = np.arange(1, 6, 0.5)
max_iterations = 20000
# Define parameters different for each task
task_params = [] task_params = []
for i, SNR in enumerate(SNRs): for i, SNR in enumerate(SNRs):
@ -88,45 +92,29 @@ def simulate(H_file, SNRs, max_iterations, omega, K, gammas):
{"decoder": decoder, "max_iterations": max_iterations, {"decoder": decoder, "max_iterations": max_iterations,
"SNR": SNR, "gamma": gamma, "n": n, "k": k}) "SNR": SNR, "gamma": gamma, "n": n, "k": k})
# Set up simulation return task_params
sim.task_params = task_params
sim.task_func = task_func
sim.start_or_continue()
return sim.current_results
def main(): def main():
# Set up simulation params
sim_name = "2d_BER_FER_DFR" sim_name = "2d_BER_FER_DFR"
# H_file = "BCH_7_4.alist"
# H_file = "BCH_31_11.alist"
# H_file = "BCH_31_26.alist"
# H_file = "96.3.965.alist"
H_file = "204.33.486.alist"
# H_file = "204.33.484.alist"
# H_file = "204.55.187.alist"
# H_file = "408.33.844.alist"
SNRs = np.arange(1, 6, 0.5)
max_iterations = 20000
omega = 0.05
K = 100
gammas = np.arange(0.0, 0.17, 0.01)
# Run simulation # Run simulation
sim = GenericMultithreadedSimulator()
sim.task_params = get_params()
sim.task_func = task_func
start_time = default_timer() start_time = default_timer()
results = simulate(H_file, SNRs, max_iterations, omega, K, gammas) sim.start_or_continue()
end_time = default_timer() end_time = default_timer()
# Show results
print(f"duration: {end_time - start_time}") print(f"duration: {end_time - start_time}")
df = misc.pgf_reformat_data_3d(results=results, x_param_name="SNR", df = misc.pgf_reformat_data_3d(results=sim.current_results,
x_param_name="SNR",
y_param_name="gamma", y_param_name="gamma",
z_param_names=["BER", "FER", "DFR", z_param_names=["BER", "FER", "DFR",
"num_iterations"]) "num_iterations"])