diff --git a/sw/simulate_2d_BER.py b/sw/simulate_2d_BER.py index 44c2c7d..e28cc20 100644 --- a/sw/simulate_2d_BER.py +++ b/sw/simulate_2d_BER.py @@ -3,9 +3,11 @@ import seaborn as sns import matplotlib.pyplot as plt import signal from timeit import default_timer +from functools import partial from utility import codes, noise, misc from utility.simulation.simulators import GenericMultithreadedSimulator +from utility.simulation import SimulationManager from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder @@ -57,19 +59,12 @@ def task_func(params): "num_iterations": num_iterations} -def get_params(): +def get_params(code_name: str): # Define global simulation parameters - # H_file = "BCH_7_4.alist" - # H_file = "BCH_31_11.alist" - # H_file = "BCH_31_26.alist" - # H_file = "96.3.965.alist" - H_file = "204.33.486.alist" - # H_file = "204.33.484.alist" - # H_file = "204.55.187.alist" - # H_file = "408.33.844.alist" + H_file = f"res/{code_name}.alist" - H = codes.read_alist_file(f"res/{H_file}") + H = codes.read_alist_file(H_file) n_min_k, n = H.shape k = n - n_min_k @@ -95,35 +90,51 @@ def get_params(): return task_params +def configure_new_simulation(sim_mgr: SimulationManager, code_name: str, + sim_name: str) -> None: + sim = GenericMultithreadedSimulator() + + sim.task_params = get_params(code_name) + sim.task_func = task_func + sim.format_func = partial(misc.pgf_reformat_data_3d, x_param_name="SNR", + y_param_name="gamma", + z_param_names=["BER", "FER", "DFR", + "num_iterations"]) + + sim_mgr.configure_simulation(simulator=sim, name=sim_name) + + def main(): - sim_name = "2d_BER_FER_DFR" + # code_name = "BCH_7_4" + # code_name = "BCH_31_11" + # code_name = "BCH_31_26" + # code_name = "96.3.965" + # code_name = "204.33.486" + code_name = "204.33.484" + # code_name = "204.55.187" + # code_name = "408.33.844" + + sim_name = f"2d_BER_FER_DFR_{misc.slugify(code_name)}" # Run simulation - sim = GenericMultithreadedSimulator() + sim_mgr = SimulationManager(saves_dir="sim_saves", + results_dir="sim_results") - sim.task_params = get_params() - sim.task_func = task_func + unfinished_sims = sim_mgr.get_unfinished() + if len(unfinished_sims) > 0: + sim_mgr.load_unfinished(unfinished_sims[0]) + else: + configure_new_simulation(sim_mgr=sim_mgr, code_name=code_name, + sim_name=sim_name) - start_time = default_timer() - sim.start_or_continue() - end_time = default_timer() + sim_mgr.simulate() - # Show results - - print(f"duration: {end_time - start_time}") - - df = misc.pgf_reformat_data_3d(results=sim.current_results, - x_param_name="SNR", - y_param_name="gamma", - z_param_names=["BER", "FER", "DFR", - "num_iterations"]) - - # df.sort_values(by=["gamma", "SNR"]).to_csv( - # f"sim_results/{sim_name}_{misc.slugify(H_file)}.csv", index=False) + # Plot results sns.set_theme() - ax = sns.lineplot(data=df, x="SNR", y="BER", hue="gamma") + ax = sns.lineplot(data=sim_mgr.get_current_results(), x="SNR", y="BER", + hue="gamma") ax.set_yscale('log') ax.set_ylim((5e-5, 2e-0)) plt.show() diff --git a/sw/utility/misc.py b/sw/utility/misc.py index 401c2d9..6619015 100644 --- a/sw/utility/misc.py +++ b/sw/utility/misc.py @@ -65,3 +65,8 @@ def pgf_reformat_data_3d(results: typing.Sequence, x_param_name: str, df[z_param_name] = zs[z_param_name] return df.sort_values(by=[x_param_name, y_param_name]) + + +def count_bit_errors(x: np.array, x_hat: np.array) -> int: + """Count the number of different bits between two words.""" + return np.sum(x != x_hat) diff --git a/sw/utility/simulation/management.py b/sw/utility/simulation/management.py index 4fc5992..f325929 100644 --- a/sw/utility/simulation/management.py +++ b/sw/utility/simulation/management.py @@ -120,16 +120,9 @@ class SimulationDeSerializer: # Save metadata self._save_metadata(sim_name, metadata) - # Save results - data = {} - for key, value in simulator.get_current_results().items(): - if not isinstance(value, collections.abc.Sequence): - value = [value] - - data[misc.slugify(key)] = value - - df = pd.DataFrame(data) - df.to_csv(self._get_results_path(sim_name)) + # Save current results + simulator.current_results.to_csv(self._get_results_path(sim_name), + index=False) def read_results(self, sim_name: str) -> typing.Tuple[ pd.DataFrame, typing.Dict]: @@ -180,13 +173,13 @@ class SimulationManager: self._metadata is not None) def configure_simulation(self, simulator: typing.Any, name: str, - column_labels: typing.Sequence[str]) -> None: + additional_metadata: dict = {}) -> None: """Configure a new simulation.""" self._simulator = simulator self._sim_name = name self._metadata["name"] = name - self._metadata["labels"] = column_labels self._metadata["platform"] = platform.platform() + self._metadata.update(additional_metadata) def get_unfinished(self) -> typing.List[str]: """Get a list of names of all present unfinished simulations.""" @@ -241,3 +234,6 @@ class SimulationManager: self._metadata) except KeyboardInterrupt: self._exit_gracefully() + + def get_current_results(self) -> pd.DataFrame: + return self._simulator.current_results diff --git a/sw/utility/simulation/simulators.py b/sw/utility/simulation/simulators.py index cc98501..363afe4 100644 --- a/sw/utility/simulation/simulators.py +++ b/sw/utility/simulation/simulators.py @@ -8,8 +8,10 @@ from multiprocessing import Lock from utility import noise + # TODO: Fix ProximalDecoder_Dynamic -# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as ProximalDecoder +# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as +# ProximalDecoder def count_bit_errors(d: np.array, d_hat: np.array) -> int: @@ -279,6 +281,7 @@ class HashableDict: class GenericMultithreadedSimulator: def __init__(self, max_workers=8): + self._format_func = None self._task_func = None self._task_params = None self._max_workers = max_workers @@ -303,7 +306,19 @@ class GenericMultithreadedSimulator: def task_func(self, func): self._task_func = func + @property + def format_func(self): + return self._format_func + + @format_func.setter + def format_func(self, func): + self._format_func = func + def start_or_continue(self): + assert self._task_func is not None + assert self._task_params is not None + assert self._format_func is not None + self._executor = ProcessPoolExecutor(max_workers=self._max_workers) with tqdm(total=(len(self._task_params)), leave=False) as pbar: @@ -330,11 +345,11 @@ class GenericMultithreadedSimulator: def stop(self): assert self._executor is not None, "The simulation has to be started" \ " before it can be stopped" - self._executor.shutdown(wait=False, cancel_futures=True) + self._executor.shutdown(wait=True, cancel_futures=True) @property def current_results(self): - return self._results + return self._format_func(self._results) def __getstate__(self): state = self.__dict__.copy()