Changed the way data formatting is handled for the simulation API to be able to properly use the SimulationManager with the GenericMultithreadedSimulator
This commit is contained in:
parent
4c5e80c56e
commit
96fcf0dd11
@ -3,9 +3,11 @@ import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import signal
|
||||
from timeit import default_timer
|
||||
from functools import partial
|
||||
|
||||
from utility import codes, noise, misc
|
||||
from utility.simulation.simulators import GenericMultithreadedSimulator
|
||||
from utility.simulation import SimulationManager
|
||||
|
||||
from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder
|
||||
|
||||
@ -57,19 +59,12 @@ def task_func(params):
|
||||
"num_iterations": num_iterations}
|
||||
|
||||
|
||||
def get_params():
|
||||
def get_params(code_name: str):
|
||||
# Define global simulation parameters
|
||||
|
||||
# H_file = "BCH_7_4.alist"
|
||||
# H_file = "BCH_31_11.alist"
|
||||
# H_file = "BCH_31_26.alist"
|
||||
# H_file = "96.3.965.alist"
|
||||
H_file = "204.33.486.alist"
|
||||
# H_file = "204.33.484.alist"
|
||||
# H_file = "204.55.187.alist"
|
||||
# H_file = "408.33.844.alist"
|
||||
H_file = f"res/{code_name}.alist"
|
||||
|
||||
H = codes.read_alist_file(f"res/{H_file}")
|
||||
H = codes.read_alist_file(H_file)
|
||||
n_min_k, n = H.shape
|
||||
k = n - n_min_k
|
||||
|
||||
@ -95,35 +90,51 @@ def get_params():
|
||||
return task_params
|
||||
|
||||
|
||||
def configure_new_simulation(sim_mgr: SimulationManager, code_name: str,
|
||||
sim_name: str) -> None:
|
||||
sim = GenericMultithreadedSimulator()
|
||||
|
||||
sim.task_params = get_params(code_name)
|
||||
sim.task_func = task_func
|
||||
sim.format_func = partial(misc.pgf_reformat_data_3d, x_param_name="SNR",
|
||||
y_param_name="gamma",
|
||||
z_param_names=["BER", "FER", "DFR",
|
||||
"num_iterations"])
|
||||
|
||||
sim_mgr.configure_simulation(simulator=sim, name=sim_name)
|
||||
|
||||
|
||||
def main():
|
||||
sim_name = "2d_BER_FER_DFR"
|
||||
# code_name = "BCH_7_4"
|
||||
# code_name = "BCH_31_11"
|
||||
# code_name = "BCH_31_26"
|
||||
# code_name = "96.3.965"
|
||||
# code_name = "204.33.486"
|
||||
code_name = "204.33.484"
|
||||
# code_name = "204.55.187"
|
||||
# code_name = "408.33.844"
|
||||
|
||||
sim_name = f"2d_BER_FER_DFR_{misc.slugify(code_name)}"
|
||||
|
||||
# Run simulation
|
||||
|
||||
sim = GenericMultithreadedSimulator()
|
||||
sim_mgr = SimulationManager(saves_dir="sim_saves",
|
||||
results_dir="sim_results")
|
||||
|
||||
sim.task_params = get_params()
|
||||
sim.task_func = task_func
|
||||
unfinished_sims = sim_mgr.get_unfinished()
|
||||
if len(unfinished_sims) > 0:
|
||||
sim_mgr.load_unfinished(unfinished_sims[0])
|
||||
else:
|
||||
configure_new_simulation(sim_mgr=sim_mgr, code_name=code_name,
|
||||
sim_name=sim_name)
|
||||
|
||||
start_time = default_timer()
|
||||
sim.start_or_continue()
|
||||
end_time = default_timer()
|
||||
sim_mgr.simulate()
|
||||
|
||||
# Show results
|
||||
|
||||
print(f"duration: {end_time - start_time}")
|
||||
|
||||
df = misc.pgf_reformat_data_3d(results=sim.current_results,
|
||||
x_param_name="SNR",
|
||||
y_param_name="gamma",
|
||||
z_param_names=["BER", "FER", "DFR",
|
||||
"num_iterations"])
|
||||
|
||||
# df.sort_values(by=["gamma", "SNR"]).to_csv(
|
||||
# f"sim_results/{sim_name}_{misc.slugify(H_file)}.csv", index=False)
|
||||
# Plot results
|
||||
|
||||
sns.set_theme()
|
||||
ax = sns.lineplot(data=df, x="SNR", y="BER", hue="gamma")
|
||||
ax = sns.lineplot(data=sim_mgr.get_current_results(), x="SNR", y="BER",
|
||||
hue="gamma")
|
||||
ax.set_yscale('log')
|
||||
ax.set_ylim((5e-5, 2e-0))
|
||||
plt.show()
|
||||
|
||||
@ -65,3 +65,8 @@ def pgf_reformat_data_3d(results: typing.Sequence, x_param_name: str,
|
||||
df[z_param_name] = zs[z_param_name]
|
||||
|
||||
return df.sort_values(by=[x_param_name, y_param_name])
|
||||
|
||||
|
||||
def count_bit_errors(x: np.array, x_hat: np.array) -> int:
|
||||
"""Count the number of different bits between two words."""
|
||||
return np.sum(x != x_hat)
|
||||
|
||||
@ -120,16 +120,9 @@ class SimulationDeSerializer:
|
||||
# Save metadata
|
||||
self._save_metadata(sim_name, metadata)
|
||||
|
||||
# Save results
|
||||
data = {}
|
||||
for key, value in simulator.get_current_results().items():
|
||||
if not isinstance(value, collections.abc.Sequence):
|
||||
value = [value]
|
||||
|
||||
data[misc.slugify(key)] = value
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
df.to_csv(self._get_results_path(sim_name))
|
||||
# Save current results
|
||||
simulator.current_results.to_csv(self._get_results_path(sim_name),
|
||||
index=False)
|
||||
|
||||
def read_results(self, sim_name: str) -> typing.Tuple[
|
||||
pd.DataFrame, typing.Dict]:
|
||||
@ -180,13 +173,13 @@ class SimulationManager:
|
||||
self._metadata is not None)
|
||||
|
||||
def configure_simulation(self, simulator: typing.Any, name: str,
|
||||
column_labels: typing.Sequence[str]) -> None:
|
||||
additional_metadata: dict = {}) -> None:
|
||||
"""Configure a new simulation."""
|
||||
self._simulator = simulator
|
||||
self._sim_name = name
|
||||
self._metadata["name"] = name
|
||||
self._metadata["labels"] = column_labels
|
||||
self._metadata["platform"] = platform.platform()
|
||||
self._metadata.update(additional_metadata)
|
||||
|
||||
def get_unfinished(self) -> typing.List[str]:
|
||||
"""Get a list of names of all present unfinished simulations."""
|
||||
@ -241,3 +234,6 @@ class SimulationManager:
|
||||
self._metadata)
|
||||
except KeyboardInterrupt:
|
||||
self._exit_gracefully()
|
||||
|
||||
def get_current_results(self) -> pd.DataFrame:
|
||||
return self._simulator.current_results
|
||||
|
||||
@ -8,8 +8,10 @@ from multiprocessing import Lock
|
||||
|
||||
from utility import noise
|
||||
|
||||
|
||||
# TODO: Fix ProximalDecoder_Dynamic
|
||||
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as ProximalDecoder
|
||||
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as
|
||||
# ProximalDecoder
|
||||
|
||||
|
||||
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
|
||||
@ -279,6 +281,7 @@ class HashableDict:
|
||||
|
||||
class GenericMultithreadedSimulator:
|
||||
def __init__(self, max_workers=8):
|
||||
self._format_func = None
|
||||
self._task_func = None
|
||||
self._task_params = None
|
||||
self._max_workers = max_workers
|
||||
@ -303,7 +306,19 @@ class GenericMultithreadedSimulator:
|
||||
def task_func(self, func):
|
||||
self._task_func = func
|
||||
|
||||
@property
|
||||
def format_func(self):
|
||||
return self._format_func
|
||||
|
||||
@format_func.setter
|
||||
def format_func(self, func):
|
||||
self._format_func = func
|
||||
|
||||
def start_or_continue(self):
|
||||
assert self._task_func is not None
|
||||
assert self._task_params is not None
|
||||
assert self._format_func is not None
|
||||
|
||||
self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
|
||||
|
||||
with tqdm(total=(len(self._task_params)), leave=False) as pbar:
|
||||
@ -330,11 +345,11 @@ class GenericMultithreadedSimulator:
|
||||
def stop(self):
|
||||
assert self._executor is not None, "The simulation has to be started" \
|
||||
" before it can be stopped"
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
self._executor.shutdown(wait=True, cancel_futures=True)
|
||||
|
||||
@property
|
||||
def current_results(self):
|
||||
return self._results
|
||||
return self._format_func(self._results)
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user