Implemented GenericMultithreadedSimulator
This commit is contained in:
parent
cee9c90c23
commit
c276ad456c
@ -2,8 +2,12 @@ import pandas as pd
|
||||
import numpy as np
|
||||
import typing
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, process, wait
|
||||
from functools import partial
|
||||
from multiprocessing import Lock
|
||||
|
||||
from utility import noise
|
||||
from cpp_modules.cpp_decoders import ProximalDecoder
|
||||
|
||||
|
||||
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
|
||||
@ -176,10 +180,10 @@ class ProximalDecoderSimulator:
|
||||
self._BERs[self._curr_decoder_index][self._curr_SNRs_index] \
|
||||
= self._curr_num_bit_errors / (
|
||||
adj_num_iterations * self._n)
|
||||
self._avg_K[self._curr_decoder_index][self._curr_SNRs_index]\
|
||||
self._avg_K[self._curr_decoder_index][self._curr_SNRs_index] \
|
||||
= \
|
||||
self._avg_K[self._curr_decoder_index][
|
||||
self._curr_SNRs_index] / adj_num_iterations
|
||||
self._avg_K[self._curr_decoder_index][
|
||||
self._curr_SNRs_index] / adj_num_iterations
|
||||
|
||||
self._dec_fails[self._curr_decoder_index][self._curr_SNRs_index] \
|
||||
= self._curr_num_dec_fails
|
||||
@ -215,7 +219,7 @@ class ProximalDecoderSimulator:
|
||||
self._decoder_pbar.close()
|
||||
self._overall_pbar.close()
|
||||
|
||||
def start(self) -> None:
|
||||
def start_or_continue(self) -> None:
|
||||
"""Start the simulation.
|
||||
|
||||
This is a blocking call. A call to the stop() function
|
||||
@ -252,4 +256,72 @@ class ProximalDecoderSimulator:
|
||||
for i, avg_K in enumerate(self._avg_K):
|
||||
data[f"AvgK_{i}"] = avg_K
|
||||
|
||||
return pd.DataFrame(data)
|
||||
return pd.DataFrame(data)
|
||||
|
||||
|
||||
class GenericMultithreadedSimulator:
|
||||
def __init__(self, max_workers=8):
|
||||
self._task_func = None
|
||||
self._task_params = None
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._results = {}
|
||||
self._executor = None
|
||||
|
||||
@property
|
||||
def task_params(self):
|
||||
return self._task_params
|
||||
|
||||
@task_params.setter
|
||||
def task_params(self, params):
|
||||
assert isinstance(params, dict)
|
||||
self._task_params = params
|
||||
|
||||
@property
|
||||
def task_func(self):
|
||||
return self._task_func
|
||||
|
||||
@task_func.setter
|
||||
def task_func(self, func):
|
||||
self._task_func = func
|
||||
|
||||
def start_or_continue(self):
|
||||
self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
|
||||
|
||||
with tqdm(total=(len(self._task_params)), leave=False) as pbar:
|
||||
def done_callback(key, f):
|
||||
try:
|
||||
pbar.update(1)
|
||||
self._results[key] = f.result()
|
||||
del self._task_params[key]
|
||||
except process.BrokenProcessPool:
|
||||
# This exception is thrown when the program is
|
||||
# prematurely stopped with a KeyboardInterrupt
|
||||
# TODO: Make sure task_params have not been removed
|
||||
pass
|
||||
|
||||
futures = []
|
||||
|
||||
for key, params in self._task_params.items():
|
||||
future = self._executor.submit(self._task_func, params)
|
||||
future.add_done_callback(partial(done_callback, key))
|
||||
futures.append(future)
|
||||
|
||||
self._executor.shutdown(wait=True, cancel_futures=False)
|
||||
|
||||
def stop(self):
|
||||
assert self._executor is not None, "The simulation has to be started" \
|
||||
" before it can be stopped"
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
def get_current_results(self):
|
||||
return self._results
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["_executor"] = None
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self._executor = ProcessPoolExecutor()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user