Changed the way data formatting is handled for the simulation API to be able to properly use the SimulationManager with the GenericMultithreadedSimulator
This commit is contained in:
parent
4c5e80c56e
commit
96fcf0dd11
@ -3,9 +3,11 @@ import seaborn as sns
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import signal
|
import signal
|
||||||
from timeit import default_timer
|
from timeit import default_timer
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from utility import codes, noise, misc
|
from utility import codes, noise, misc
|
||||||
from utility.simulation.simulators import GenericMultithreadedSimulator
|
from utility.simulation.simulators import GenericMultithreadedSimulator
|
||||||
|
from utility.simulation import SimulationManager
|
||||||
|
|
||||||
from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder
|
from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder
|
||||||
|
|
||||||
@ -57,19 +59,12 @@ def task_func(params):
|
|||||||
"num_iterations": num_iterations}
|
"num_iterations": num_iterations}
|
||||||
|
|
||||||
|
|
||||||
def get_params():
|
def get_params(code_name: str):
|
||||||
# Define global simulation parameters
|
# Define global simulation parameters
|
||||||
|
|
||||||
# H_file = "BCH_7_4.alist"
|
H_file = f"res/{code_name}.alist"
|
||||||
# H_file = "BCH_31_11.alist"
|
|
||||||
# H_file = "BCH_31_26.alist"
|
|
||||||
# H_file = "96.3.965.alist"
|
|
||||||
H_file = "204.33.486.alist"
|
|
||||||
# H_file = "204.33.484.alist"
|
|
||||||
# H_file = "204.55.187.alist"
|
|
||||||
# H_file = "408.33.844.alist"
|
|
||||||
|
|
||||||
H = codes.read_alist_file(f"res/{H_file}")
|
H = codes.read_alist_file(H_file)
|
||||||
n_min_k, n = H.shape
|
n_min_k, n = H.shape
|
||||||
k = n - n_min_k
|
k = n - n_min_k
|
||||||
|
|
||||||
@ -95,35 +90,51 @@ def get_params():
|
|||||||
return task_params
|
return task_params
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def configure_new_simulation(sim_mgr: SimulationManager, code_name: str,
|
||||||
sim_name = "2d_BER_FER_DFR"
|
sim_name: str) -> None:
|
||||||
|
|
||||||
# Run simulation
|
|
||||||
|
|
||||||
sim = GenericMultithreadedSimulator()
|
sim = GenericMultithreadedSimulator()
|
||||||
|
|
||||||
sim.task_params = get_params()
|
sim.task_params = get_params(code_name)
|
||||||
sim.task_func = task_func
|
sim.task_func = task_func
|
||||||
|
sim.format_func = partial(misc.pgf_reformat_data_3d, x_param_name="SNR",
|
||||||
start_time = default_timer()
|
|
||||||
sim.start_or_continue()
|
|
||||||
end_time = default_timer()
|
|
||||||
|
|
||||||
# Show results
|
|
||||||
|
|
||||||
print(f"duration: {end_time - start_time}")
|
|
||||||
|
|
||||||
df = misc.pgf_reformat_data_3d(results=sim.current_results,
|
|
||||||
x_param_name="SNR",
|
|
||||||
y_param_name="gamma",
|
y_param_name="gamma",
|
||||||
z_param_names=["BER", "FER", "DFR",
|
z_param_names=["BER", "FER", "DFR",
|
||||||
"num_iterations"])
|
"num_iterations"])
|
||||||
|
|
||||||
# df.sort_values(by=["gamma", "SNR"]).to_csv(
|
sim_mgr.configure_simulation(simulator=sim, name=sim_name)
|
||||||
# f"sim_results/{sim_name}_{misc.slugify(H_file)}.csv", index=False)
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# code_name = "BCH_7_4"
|
||||||
|
# code_name = "BCH_31_11"
|
||||||
|
# code_name = "BCH_31_26"
|
||||||
|
# code_name = "96.3.965"
|
||||||
|
# code_name = "204.33.486"
|
||||||
|
code_name = "204.33.484"
|
||||||
|
# code_name = "204.55.187"
|
||||||
|
# code_name = "408.33.844"
|
||||||
|
|
||||||
|
sim_name = f"2d_BER_FER_DFR_{misc.slugify(code_name)}"
|
||||||
|
|
||||||
|
# Run simulation
|
||||||
|
|
||||||
|
sim_mgr = SimulationManager(saves_dir="sim_saves",
|
||||||
|
results_dir="sim_results")
|
||||||
|
|
||||||
|
unfinished_sims = sim_mgr.get_unfinished()
|
||||||
|
if len(unfinished_sims) > 0:
|
||||||
|
sim_mgr.load_unfinished(unfinished_sims[0])
|
||||||
|
else:
|
||||||
|
configure_new_simulation(sim_mgr=sim_mgr, code_name=code_name,
|
||||||
|
sim_name=sim_name)
|
||||||
|
|
||||||
|
sim_mgr.simulate()
|
||||||
|
|
||||||
|
# Plot results
|
||||||
|
|
||||||
sns.set_theme()
|
sns.set_theme()
|
||||||
ax = sns.lineplot(data=df, x="SNR", y="BER", hue="gamma")
|
ax = sns.lineplot(data=sim_mgr.get_current_results(), x="SNR", y="BER",
|
||||||
|
hue="gamma")
|
||||||
ax.set_yscale('log')
|
ax.set_yscale('log')
|
||||||
ax.set_ylim((5e-5, 2e-0))
|
ax.set_ylim((5e-5, 2e-0))
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
@ -65,3 +65,8 @@ def pgf_reformat_data_3d(results: typing.Sequence, x_param_name: str,
|
|||||||
df[z_param_name] = zs[z_param_name]
|
df[z_param_name] = zs[z_param_name]
|
||||||
|
|
||||||
return df.sort_values(by=[x_param_name, y_param_name])
|
return df.sort_values(by=[x_param_name, y_param_name])
|
||||||
|
|
||||||
|
|
||||||
|
def count_bit_errors(x: np.array, x_hat: np.array) -> int:
|
||||||
|
"""Count the number of different bits between two words."""
|
||||||
|
return np.sum(x != x_hat)
|
||||||
|
|||||||
@ -120,16 +120,9 @@ class SimulationDeSerializer:
|
|||||||
# Save metadata
|
# Save metadata
|
||||||
self._save_metadata(sim_name, metadata)
|
self._save_metadata(sim_name, metadata)
|
||||||
|
|
||||||
# Save results
|
# Save current results
|
||||||
data = {}
|
simulator.current_results.to_csv(self._get_results_path(sim_name),
|
||||||
for key, value in simulator.get_current_results().items():
|
index=False)
|
||||||
if not isinstance(value, collections.abc.Sequence):
|
|
||||||
value = [value]
|
|
||||||
|
|
||||||
data[misc.slugify(key)] = value
|
|
||||||
|
|
||||||
df = pd.DataFrame(data)
|
|
||||||
df.to_csv(self._get_results_path(sim_name))
|
|
||||||
|
|
||||||
def read_results(self, sim_name: str) -> typing.Tuple[
|
def read_results(self, sim_name: str) -> typing.Tuple[
|
||||||
pd.DataFrame, typing.Dict]:
|
pd.DataFrame, typing.Dict]:
|
||||||
@ -180,13 +173,13 @@ class SimulationManager:
|
|||||||
self._metadata is not None)
|
self._metadata is not None)
|
||||||
|
|
||||||
def configure_simulation(self, simulator: typing.Any, name: str,
|
def configure_simulation(self, simulator: typing.Any, name: str,
|
||||||
column_labels: typing.Sequence[str]) -> None:
|
additional_metadata: dict = {}) -> None:
|
||||||
"""Configure a new simulation."""
|
"""Configure a new simulation."""
|
||||||
self._simulator = simulator
|
self._simulator = simulator
|
||||||
self._sim_name = name
|
self._sim_name = name
|
||||||
self._metadata["name"] = name
|
self._metadata["name"] = name
|
||||||
self._metadata["labels"] = column_labels
|
|
||||||
self._metadata["platform"] = platform.platform()
|
self._metadata["platform"] = platform.platform()
|
||||||
|
self._metadata.update(additional_metadata)
|
||||||
|
|
||||||
def get_unfinished(self) -> typing.List[str]:
|
def get_unfinished(self) -> typing.List[str]:
|
||||||
"""Get a list of names of all present unfinished simulations."""
|
"""Get a list of names of all present unfinished simulations."""
|
||||||
@ -241,3 +234,6 @@ class SimulationManager:
|
|||||||
self._metadata)
|
self._metadata)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self._exit_gracefully()
|
self._exit_gracefully()
|
||||||
|
|
||||||
|
def get_current_results(self) -> pd.DataFrame:
|
||||||
|
return self._simulator.current_results
|
||||||
|
|||||||
@ -8,8 +8,10 @@ from multiprocessing import Lock
|
|||||||
|
|
||||||
from utility import noise
|
from utility import noise
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fix ProximalDecoder_Dynamic
|
# TODO: Fix ProximalDecoder_Dynamic
|
||||||
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as ProximalDecoder
|
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as
|
||||||
|
# ProximalDecoder
|
||||||
|
|
||||||
|
|
||||||
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
|
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
|
||||||
@ -279,6 +281,7 @@ class HashableDict:
|
|||||||
|
|
||||||
class GenericMultithreadedSimulator:
|
class GenericMultithreadedSimulator:
|
||||||
def __init__(self, max_workers=8):
|
def __init__(self, max_workers=8):
|
||||||
|
self._format_func = None
|
||||||
self._task_func = None
|
self._task_func = None
|
||||||
self._task_params = None
|
self._task_params = None
|
||||||
self._max_workers = max_workers
|
self._max_workers = max_workers
|
||||||
@ -303,7 +306,19 @@ class GenericMultithreadedSimulator:
|
|||||||
def task_func(self, func):
|
def task_func(self, func):
|
||||||
self._task_func = func
|
self._task_func = func
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format_func(self):
|
||||||
|
return self._format_func
|
||||||
|
|
||||||
|
@format_func.setter
|
||||||
|
def format_func(self, func):
|
||||||
|
self._format_func = func
|
||||||
|
|
||||||
def start_or_continue(self):
|
def start_or_continue(self):
|
||||||
|
assert self._task_func is not None
|
||||||
|
assert self._task_params is not None
|
||||||
|
assert self._format_func is not None
|
||||||
|
|
||||||
self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
|
self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
|
||||||
|
|
||||||
with tqdm(total=(len(self._task_params)), leave=False) as pbar:
|
with tqdm(total=(len(self._task_params)), leave=False) as pbar:
|
||||||
@ -330,11 +345,11 @@ class GenericMultithreadedSimulator:
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
assert self._executor is not None, "The simulation has to be started" \
|
assert self._executor is not None, "The simulation has to be started" \
|
||||||
" before it can be stopped"
|
" before it can be stopped"
|
||||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
self._executor.shutdown(wait=True, cancel_futures=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_results(self):
|
def current_results(self):
|
||||||
return self._results
|
return self._format_func(self._results)
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = self.__dict__.copy()
|
state = self.__dict__.copy()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user