Changed the way data formatting is handled for the simulation API to be able to properly use the SimulationManager with the GenericMultithreadedSimulator

This commit is contained in:
Andreas Tsouchlos 2022-12-05 16:25:47 +01:00
parent 4c5e80c56e
commit 96fcf0dd11
4 changed files with 72 additions and 45 deletions

View File

@ -3,9 +3,11 @@ import seaborn as sns
import matplotlib.pyplot as plt
import signal
from timeit import default_timer
from functools import partial
from utility import codes, noise, misc
from utility.simulation.simulators import GenericMultithreadedSimulator
from utility.simulation import SimulationManager
from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder
@ -57,19 +59,12 @@ def task_func(params):
"num_iterations": num_iterations}
def get_params():
def get_params(code_name: str):
# Define global simulation parameters
# H_file = "BCH_7_4.alist"
# H_file = "BCH_31_11.alist"
# H_file = "BCH_31_26.alist"
# H_file = "96.3.965.alist"
H_file = "204.33.486.alist"
# H_file = "204.33.484.alist"
# H_file = "204.55.187.alist"
# H_file = "408.33.844.alist"
H_file = f"res/{code_name}.alist"
H = codes.read_alist_file(f"res/{H_file}")
H = codes.read_alist_file(H_file)
n_min_k, n = H.shape
k = n - n_min_k
@ -95,35 +90,51 @@ def get_params():
return task_params
def configure_new_simulation(sim_mgr: SimulationManager, code_name: str,
sim_name: str) -> None:
sim = GenericMultithreadedSimulator()
sim.task_params = get_params(code_name)
sim.task_func = task_func
sim.format_func = partial(misc.pgf_reformat_data_3d, x_param_name="SNR",
y_param_name="gamma",
z_param_names=["BER", "FER", "DFR",
"num_iterations"])
sim_mgr.configure_simulation(simulator=sim, name=sim_name)
def main():
sim_name = "2d_BER_FER_DFR"
# code_name = "BCH_7_4"
# code_name = "BCH_31_11"
# code_name = "BCH_31_26"
# code_name = "96.3.965"
# code_name = "204.33.486"
code_name = "204.33.484"
# code_name = "204.55.187"
# code_name = "408.33.844"
sim_name = f"2d_BER_FER_DFR_{misc.slugify(code_name)}"
# Run simulation
sim = GenericMultithreadedSimulator()
sim_mgr = SimulationManager(saves_dir="sim_saves",
results_dir="sim_results")
sim.task_params = get_params()
sim.task_func = task_func
unfinished_sims = sim_mgr.get_unfinished()
if len(unfinished_sims) > 0:
sim_mgr.load_unfinished(unfinished_sims[0])
else:
configure_new_simulation(sim_mgr=sim_mgr, code_name=code_name,
sim_name=sim_name)
start_time = default_timer()
sim.start_or_continue()
end_time = default_timer()
sim_mgr.simulate()
# Show results
print(f"duration: {end_time - start_time}")
df = misc.pgf_reformat_data_3d(results=sim.current_results,
x_param_name="SNR",
y_param_name="gamma",
z_param_names=["BER", "FER", "DFR",
"num_iterations"])
# df.sort_values(by=["gamma", "SNR"]).to_csv(
# f"sim_results/{sim_name}_{misc.slugify(H_file)}.csv", index=False)
# Plot results
sns.set_theme()
ax = sns.lineplot(data=df, x="SNR", y="BER", hue="gamma")
ax = sns.lineplot(data=sim_mgr.get_current_results(), x="SNR", y="BER",
hue="gamma")
ax.set_yscale('log')
ax.set_ylim((5e-5, 2e-0))
plt.show()

View File

@ -65,3 +65,8 @@ def pgf_reformat_data_3d(results: typing.Sequence, x_param_name: str,
df[z_param_name] = zs[z_param_name]
return df.sort_values(by=[x_param_name, y_param_name])
def count_bit_errors(x: np.array, x_hat: np.array) -> int:
"""Count the number of different bits between two words."""
return np.sum(x != x_hat)

View File

@ -120,16 +120,9 @@ class SimulationDeSerializer:
# Save metadata
self._save_metadata(sim_name, metadata)
# Save results
data = {}
for key, value in simulator.get_current_results().items():
if not isinstance(value, collections.abc.Sequence):
value = [value]
data[misc.slugify(key)] = value
df = pd.DataFrame(data)
df.to_csv(self._get_results_path(sim_name))
# Save current results
simulator.current_results.to_csv(self._get_results_path(sim_name),
index=False)
def read_results(self, sim_name: str) -> typing.Tuple[
pd.DataFrame, typing.Dict]:
@ -180,13 +173,13 @@ class SimulationManager:
self._metadata is not None)
def configure_simulation(self, simulator: typing.Any, name: str,
column_labels: typing.Sequence[str]) -> None:
additional_metadata: dict = {}) -> None:
"""Configure a new simulation."""
self._simulator = simulator
self._sim_name = name
self._metadata["name"] = name
self._metadata["labels"] = column_labels
self._metadata["platform"] = platform.platform()
self._metadata.update(additional_metadata)
def get_unfinished(self) -> typing.List[str]:
"""Get a list of names of all present unfinished simulations."""
@ -241,3 +234,6 @@ class SimulationManager:
self._metadata)
except KeyboardInterrupt:
self._exit_gracefully()
def get_current_results(self) -> pd.DataFrame:
return self._simulator.current_results

View File

@ -8,8 +8,10 @@ from multiprocessing import Lock
from utility import noise
# TODO: Fix ProximalDecoder_Dynamic
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as ProximalDecoder
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as
# ProximalDecoder
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
@ -279,6 +281,7 @@ class HashableDict:
class GenericMultithreadedSimulator:
def __init__(self, max_workers=8):
self._format_func = None
self._task_func = None
self._task_params = None
self._max_workers = max_workers
@ -303,7 +306,19 @@ class GenericMultithreadedSimulator:
def task_func(self, func):
self._task_func = func
@property
def format_func(self):
return self._format_func
@format_func.setter
def format_func(self, func):
self._format_func = func
def start_or_continue(self):
assert self._task_func is not None
assert self._task_params is not None
assert self._format_func is not None
self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
with tqdm(total=(len(self._task_params)), leave=False) as pbar:
@ -330,11 +345,11 @@ class GenericMultithreadedSimulator:
def stop(self):
assert self._executor is not None, "The simulation has to be started" \
" before it can be stopped"
self._executor.shutdown(wait=False, cancel_futures=True)
self._executor.shutdown(wait=True, cancel_futures=True)
@property
def current_results(self):
return self._results
return self._format_func(self._results)
def __getstate__(self):
state = self.__dict__.copy()