Changed the way data formatting is handled for the simulation API to be able to properly use the SimulationManager with the GenericMultithreadedSimulator

This commit is contained in:
2022-12-05 16:25:47 +01:00
parent 4c5e80c56e
commit 96fcf0dd11
4 changed files with 72 additions and 45 deletions

View File

@@ -120,16 +120,9 @@ class SimulationDeSerializer:
# Save metadata
self._save_metadata(sim_name, metadata)
# Save results
data = {}
for key, value in simulator.get_current_results().items():
if not isinstance(value, collections.abc.Sequence):
value = [value]
data[misc.slugify(key)] = value
df = pd.DataFrame(data)
df.to_csv(self._get_results_path(sim_name))
# Save current results
simulator.current_results.to_csv(self._get_results_path(sim_name),
index=False)
def read_results(self, sim_name: str) -> typing.Tuple[
pd.DataFrame, typing.Dict]:
@@ -180,13 +173,13 @@ class SimulationManager:
self._metadata is not None)
def configure_simulation(self, simulator: typing.Any, name: str,
column_labels: typing.Sequence[str]) -> None:
additional_metadata: dict = {}) -> None:
"""Configure a new simulation."""
self._simulator = simulator
self._sim_name = name
self._metadata["name"] = name
self._metadata["labels"] = column_labels
self._metadata["platform"] = platform.platform()
self._metadata.update(additional_metadata)
def get_unfinished(self) -> typing.List[str]:
"""Get a list of names of all present unfinished simulations."""
@@ -241,3 +234,6 @@ class SimulationManager:
self._metadata)
except KeyboardInterrupt:
self._exit_gracefully()
def get_current_results(self) -> pd.DataFrame:
return self._simulator.current_results

View File

@@ -8,8 +8,10 @@ from multiprocessing import Lock
from utility import noise
# TODO: Fix ProximalDecoder_Dynamic
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as ProximalDecoder
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as
# ProximalDecoder
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
@@ -279,6 +281,7 @@ class HashableDict:
class GenericMultithreadedSimulator:
def __init__(self, max_workers=8):
self._format_func = None
self._task_func = None
self._task_params = None
self._max_workers = max_workers
@@ -303,7 +306,19 @@ class GenericMultithreadedSimulator:
def task_func(self, func):
self._task_func = func
@property
def format_func(self):
return self._format_func
@format_func.setter
def format_func(self, func):
self._format_func = func
def start_or_continue(self):
assert self._task_func is not None
assert self._task_params is not None
assert self._format_func is not None
self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
with tqdm(total=(len(self._task_params)), leave=False) as pbar:
@@ -330,11 +345,11 @@ class GenericMultithreadedSimulator:
def stop(self):
assert self._executor is not None, "The simulation has to be started" \
" before it can be stopped"
self._executor.shutdown(wait=False, cancel_futures=True)
self._executor.shutdown(wait=True, cancel_futures=True)
@property
def current_results(self):
return self._results
return self._format_func(self._results)
def __getstate__(self):
state = self.__dict__.copy()