Reformatted all code

This commit is contained in:
Andreas Tsouchlos 2022-11-22 17:35:51 +01:00
parent 3c27a0cc18
commit 8913246600
11 changed files with 164 additions and 94 deletions

View File

@ -1 +1,2 @@
"""This package contains a number of different decoder implementations for LDPC codes.""" """This package contains a number of different decoder implementations for
LDPC codes."""

View File

@ -3,9 +3,9 @@ import itertools
class MLDecoder: class MLDecoder:
"""This class naively implements a soft decision decoder. The decoder calculates """This class naively implements a soft decision decoder. The decoder
the correlation between the received signal and each codeword and then chooses the calculates the correlation between the received signal and each codeword
one with the largest correlation. and then chooses the one with the largest correlation.
""" """
def __init__(self, G: np.array, H: np.array): def __init__(self, G: np.array, H: np.array):
@ -17,13 +17,15 @@ class MLDecoder:
self._G = G self._G = G
self._H = H self._H = H
self._datawords, self._codewords = self._gen_codewords() self._datawords, self._codewords = self._gen_codewords()
self._codewords_bpsk = 1 - 2 * self._codewords # The codewords, but mapped to [-1, 1]^n
# The codewords, but mapped to [-1, 1]^n
self._codewords_bpsk = 1 - 2 * self._codewords
def _gen_codewords(self) -> np.array: def _gen_codewords(self) -> np.array:
"""Generate a list of all possible codewords. """Generate a list of all possible codewords.
:return: Numpy array of the form [[codeword_1], [codeword_2], ...] :return: Numpy array of the form [[codeword_1], [codeword_2], ...]
(Each generated codeword is an element of [0, 1]^n) (Each generated codeword is an element of [0, 1]^n)
""" """
k, n = self._G.shape k, n = self._G.shape
@ -41,8 +43,8 @@ class MLDecoder:
This function assumes a BPSK modulated signal. This function assumes a BPSK modulated signal.
:param y: Vector of received values. (y = x + w, where 'x' is element of [-1, 1]^n :param y: Vector of received values. (y = x + w, where 'x' is
and 'w' is noise) element of [-1, 1]^n and 'w' is noise)
:return: Most probably sent codeword (element of [0, 1]^k) :return: Most probably sent codeword (element of [0, 1]^k)
""" """
correlations = np.dot(self._codewords_bpsk, y) correlations = np.dot(self._codewords_bpsk, y)

View File

@ -2,7 +2,8 @@ import numpy as np
class ProximalDecoder: class ProximalDecoder:
"""Class implementing the Proximal Decoding algorithm. See "Proximal Decoding for LDPC Codes" """Class implementing the Proximal Decoding algorithm. See "Proximal
Decoding for LDPC Codes"
by Tadashi Wadayama, and Satoshi Takabe. by Tadashi Wadayama, and Satoshi Takabe.
""" """
@ -13,7 +14,8 @@ class ProximalDecoder:
:param H: Parity Check Matrix :param H: Parity Check Matrix
:param K: Max number of iterations to perform when decoding :param K: Max number of iterations to perform when decoding
:param omega: Step size for the gradient descent process :param omega: Step size for the gradient descent process
:param gamma: Positive constant. Arises in the approximation of the prior PDF :param gamma: Positive constant. Arises in the approximation of the
prior PDF
:param eta: Positive constant slightly larger than one. See 3.2, p. 3 :param eta: Positive constant slightly larger than one. See 3.2, p. 3
""" """
self._H = H self._H = H
@ -27,8 +29,8 @@ class ProximalDecoder:
@staticmethod @staticmethod
def _L_awgn(s: np.array, y: np.array) -> np.array: def _L_awgn(s: np.array, y: np.array) -> np.array:
"""Variation of the negative log-likelihood for the special case of AWGN noise. """Variation of the negative log-likelihood for the special case of
See 4.1, p. 4. AWGN noise. See 4.1, p. 4.
""" """
return s - y return s - y
@ -41,15 +43,15 @@ class ProximalDecoder:
# Calculate gradient # Calculate gradient
sums = np.dot(A_prods**2 - A_prods, self._H) sums = np.dot(A_prods ** 2 - A_prods, self._H)
result = 4 * (x**2 - 1) * x + (2 / x) * sums result = 4 * (x ** 2 - 1) * x + (2 / x) * sums
return result return result
def _projection(self, v): def _projection(self, v):
"""Project a vector onto [-eta, eta]^n in order to avoid numerical instability. """Project a vector onto [-eta, eta]^n in order to avoid numerical
Detailed in 3.2, p. 3 (Equation (15)). instability. Detailed in 3.2, p. 3 (Equation (15)).
:param v: Vector to project :param v: Vector to project
:return: x clipped to [-eta, eta]^n :return: x clipped to [-eta, eta]^n
@ -60,7 +62,8 @@ class ProximalDecoder:
"""Perform a parity check for a given codeword. """Perform a parity check for a given codeword.
:param x_hat: codeword to be checked (element of [0, 1]^n) :param x_hat: codeword to be checked (element of [0, 1]^n)
:return: True if the parity check passes, i.e. the codeword is valid. False otherwise :return: True if the parity check passes, i.e. the codeword is
valid. False otherwise
""" """
syndrome = np.dot(self._H, x_hat) % 2 syndrome = np.dot(self._H, x_hat) % 2
return not np.any(syndrome) return not np.any(syndrome)
@ -70,8 +73,8 @@ class ProximalDecoder:
This function assumes a BPSK modulated signal and an AWGN channel. This function assumes a BPSK modulated signal and an AWGN channel.
:param y: Vector of received values. (y = x + w, where 'x' is element of [-1, 1]^n :param y: Vector of received values. (y = x + w, where 'x' is
and 'w' is noise) element of [-1, 1]^n and 'w' is noise)
:return: Most probably sent codeword (element of [0, 1]^n) :return: Most probably sent codeword (element of [0, 1]^n)
""" """
s = np.zeros(self._n) s = np.zeros(self._n)
@ -83,7 +86,9 @@ class ProximalDecoder:
s = self._projection(s) # Equation (15) s = self._projection(s) # Equation (15)
x_hat = np.sign(s) x_hat = np.sign(s)
x_hat = (x_hat == -1) * 1 # Map the codeword from [-1, 1]^n to [0, 1]^n
# Map the codeword from [ -1, 1]^n to [0, 1]^n
x_hat = (x_hat == -1) * 1
if self._check_parity(x_hat): if self._check_parity(x_hat):
break break

View File

@ -1,7 +1,5 @@
import numpy as np import numpy as np
import pandas as pd
from pathlib import Path from pathlib import Path
import sys
from decoders import proximal, maximum_likelihood from decoders import proximal, maximum_likelihood
from utility import simulation, codes from utility import simulation, codes
@ -12,7 +10,8 @@ def main():
sim_name = "test" sim_name = "test"
sim_mgr = simulation.SimulationManager(results_dir="sim_results", save_dir="sim_saves") sim_mgr = simulation.SimulationManager(results_dir="sim_results",
save_dir="sim_saves")
if sim_mgr.unfinished_simulation_present(sim_name): if sim_mgr.unfinished_simulation_present(sim_name):
print("Found unfinished simulation. Picking up where it was left of") print("Found unfinished simulation. Picking up where it was left of")
@ -28,8 +27,7 @@ def main():
# H = codes.read_alist_file("res/999.111.3.5543.alist") # H = codes.read_alist_file("res/999.111.3.5543.alist")
# H = codes.read_alist_file("res/999.111.3.5565.alist") # H = codes.read_alist_file("res/999.111.3.5565.alist")
H = codes.read_alist_file("res/816.1A4.845.alist") H = codes.read_alist_file("res/816.1A4.845.alist")
k = 272 k, n = H.shape
n = 816
decoders = [ decoders = [
proximal.ProximalDecoder(H, gamma=0.01), proximal.ProximalDecoder(H, gamma=0.01),
@ -43,9 +41,12 @@ def main():
"proximal $\\gamma = 0.15$" "proximal $\\gamma = 0.15$"
] ]
sim = simulation.Simulator(n=n, k=k, decoders=decoders, target_frame_errors=3, SNRs=np.arange(1, 6, 0.5)) sim = simulation.Simulator(n=n, k=k, decoders=decoders,
target_frame_errors=3,
SNRs=np.arange(1, 6, 0.5))
sim_mgr.configure_simulation(simulator=sim, name=sim_name, column_labels=labels) sim_mgr.configure_simulation(simulator=sim, name=sim_name,
column_labels=labels)
sim_mgr.simulate() sim_mgr.simulate()

View File

@ -7,7 +7,6 @@ from utility import visualization, simulation
def plot_results(): def plot_results():
sim_names = [ sim_names = [
"96.3965", "96.3965",
"204.3.486", "204.3.486",
@ -19,7 +18,8 @@ def plot_results():
"PEGReg252x504" "PEGReg252x504"
] ]
deserializer = simulation.SimulationDeSerializer(save_dir="sim_saves", results_dir="sim_results") deserializer = simulation.SimulationDeSerializer(save_dir="sim_saves",
results_dir="sim_results")
data = [] data = []
for sim_name in sim_names: for sim_name in sim_names:
@ -33,8 +33,9 @@ def plot_results():
data.append(graph_tuple) data.append(graph_tuple)
sns.set_theme() sns.set_theme()
fig = visualization.plot_BERs(title="Bit-Error-Rates of proximal decoder for different codes", fig = visualization.plot_BERs(
data=data, num_cols=4) title="Bit-Error-Rates of proximal decoder for different codes",
data=data, num_cols=4)
plt.show() plt.show()

View File

@ -1 +1,2 @@
"""This package contains various utilities that can be used in combination with the decoders.""" """This package contains various utilities that can be used in combination
with the decoders."""

View File

@ -1,8 +1,8 @@
"""This file Helper functions for generating an H matrix from alist data. """This file Helper functions for generating an H matrix from alist data.
Code from https://github.com/gnuradio/gnuradio/blob/master/gr-fec/python/fec/LDPC/Generate_LDPC_matrix_functions.py Code from https://github.com/gnuradio/gnuradio/blob/master/gr-fec/python/fec
/LDPC/Generate_LDPC_matrix_functions.py
""" """
import numpy as np import numpy as np
@ -44,6 +44,8 @@ def read_alist_file(filename):
# #
# @formatter:off
Gs = {'Hamming_7_4': np.array([[1, 0, 0, 0, 0, 1, 1], Gs = {'Hamming_7_4': np.array([[1, 0, 0, 0, 0, 1, 1],
[0, 1, 0, 0, 1, 0, 1], [0, 1, 0, 0, 1, 0, 1],
[0, 0, 1, 0, 1, 1, 0], [0, 0, 1, 0, 1, 1, 0],
@ -219,6 +221,8 @@ Gs = {'Hamming_7_4': np.array([[1, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1]]) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1]])
} }
# @formatter:on
# #
# Utilities for systematic codes # Utilities for systematic codes

View File

@ -4,7 +4,8 @@ import re
def slugify(value, allow_unicode=False): def slugify(value, allow_unicode=False):
""" """
Taken from https://github.com/django/django/blob/master/django/utils/text.py Taken from https://github.com/django/django/blob/master/django/utils
/text.py
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics, dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Convert to lowercase. Also strip leading and underscores, or hyphens. Convert to lowercase. Also strip leading and
@ -14,6 +15,8 @@ def slugify(value, allow_unicode=False):
if allow_unicode: if allow_unicode:
value = unicodedata.normalize('NFKC', value) value = unicodedata.normalize('NFKC', value)
else: else:
value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii') value = unicodedata.normalize('NFKD', value).encode('ascii',
'ignore').decode(
'ascii')
value = re.sub(r'[^\w\s-]', '', value.lower()) value = re.sub(r'[^\w\s-]', '', value.lower())
return re.sub(r'[-\s]+', '-', value).strip('-_') return re.sub(r'[-\s]+', '-', value).strip('-_')

View File

@ -1,11 +1,11 @@
"""Utility functions relating to noise and SNR calculations.""" """Utility functions relating to noise and SNR calculations."""
import numpy as np import numpy as np
def get_noise_variance_from_SNR(SNR: float, n: int, k: int) -> float: def get_noise_variance_from_SNR(SNR: float, n: int, k: int) -> float:
"""Calculate the variance of the noise from an SNR and the signal amplitude. """Calculate the variance of the noise from an SNR and the signal
amplitude.
:param SNR: Signal-to-Noise-Ratio in dB (E_b/N_0) :param SNR: Signal-to-Noise-Ratio in dB (E_b/N_0)
:param n: Length of a codeword of the used code :param n: Length of a codeword of the used code
@ -13,14 +13,15 @@ def get_noise_variance_from_SNR(SNR: float, n: int, k: int) -> float:
:return: Variance of the noise :return: Variance of the noise
""" """
SNR_linear = 10 ** (SNR / 10) SNR_linear = 10 ** (SNR / 10)
variance = 1 / (2 * (k/n) * SNR_linear) variance = 1 / (2 * (k / n) * SNR_linear)
return variance return variance
def add_awgn(c: np.array, SNR: float, n: int, k: int) -> np.array: def add_awgn(c: np.array, SNR: float, n: int, k: int) -> np.array:
"""Add Additive White Gaussian Noise to a data vector. As this function adds random noise to """Add Additive White Gaussian Noise to a data vector. As this function
the input, the output changes, even if it is called multiple times with the same input. adds random noise to the input, the output changes, even if it is called
multiple times with the same input.
:param c: Binary vector representing the data to be transmitted :param c: Binary vector representing the data to be transmitted
:param SNR: Signal-to-Noise-Ratio in dB :param SNR: Signal-to-Noise-Ratio in dB

View File

@ -1,4 +1,5 @@
"""This file contains utility functions relating to tests and simulations of the decoders.""" """This file contains utility functions relating to tests and simulations of
the decoders."""
import json import json
import pandas as pd import pandas as pd
import numpy as np import numpy as np
@ -23,15 +24,18 @@ def count_bit_errors(d: np.array, d_hat: np.array) -> int:
# TODO: Write unit tests # TODO: Write unit tests
# TODO: Create generic Simulator Interface which should be implemented for
# specific applications
class Simulator: class Simulator:
"""Class allowing for saving of simulations state. """Class allowing for saving of simulations state.
Given a list of decoders, this class allows for simulating the Bit-Error-Rates of each decoder Given a list of decoders, this class allows for simulating the
for various SNRs. Bit-Error-Rates of each decoder for various SNRs.
The functionality implemented by this class could be achieved by a bunch of loops and a function. The functionality implemented by this class could be achieved by a bunch
However, storing the state of the simulation as member variables allows for pausing and resuming of loops and a function. However, storing the state of the simulation as
the simulation at a later time. member variables allows for pausing and resuming the simulation at a
later time.
""" """
def __init__(self, n: int, k: int, def __init__(self, n: int, k: int,
@ -44,7 +48,8 @@ class Simulator:
:param k: Number of bits in a dataword :param k: Number of bits in a dataword
:param decoders: Sequence of decoders to test :param decoders: Sequence of decoders to test
:param SNRs: Sequence of SNRs for which the BERs should be calculated :param SNRs: Sequence of SNRs for which the BERs should be calculated
:param target_frame_errors: Number of frame errors after which to stop the simulation :param target_frame_errors: Number of frame errors after which to
stop the simulation
""" """
# Simulation parameters # Simulation parameters
@ -76,20 +81,26 @@ class Simulator:
def _create_pbars(self): def _create_pbars(self):
self._overall_pbar = tqdm(total=len(self._decoders), self._overall_pbar = tqdm(total=len(self._decoders),
desc="Calculating the answer to life, the universe and everything", desc="Calculating the answer to life, "
"the universe and everything",
leave=False, leave=False,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}") bar_format="{l_bar}{bar}| {n_fmt}/{"
"total_fmt}")
decoder = self._decoders[self._current_decoder_index] decoder = self._decoders[self._current_decoder_index]
self._decoder_pbar = tqdm(total=len(self._SNRs), self._decoder_pbar = tqdm(total=len(self._SNRs),
desc=f"Calculating BERs for {decoder.__class__.__name__}", desc=f"Calculatin"
f"g BERs"
f" for {decoder.__class__.__name__}",
leave=False, leave=False,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}") bar_format="{l_bar}{bar}| {n_fmt}/{"
"total_fmt}")
self._snr_pbar = tqdm(total=self._target_frame_errors, self._snr_pbar = tqdm(total=self._target_frame_errors,
desc=f"Simulating for SNR = {self._SNRs[0]} dB", desc=f"Simulating for SNR = {self._SNRs[0]} dB",
leave=False, leave=False,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]") bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} "
"[{elapsed}<{remaining}]")
def __getstate__(self) -> typing.Dict: def __getstate__(self) -> typing.Dict:
"""Custom serialization function called by the 'pickle' module """Custom serialization function called by the 'pickle' module
@ -105,7 +116,8 @@ class Simulator:
"""Custom deserialization function called by the 'pickle' module """Custom deserialization function called by the 'pickle' module
when loading a previously saved simulation when loading a previously saved simulation
:param state: Dictionary storing the serialized version of an object of this class :param state: Dictionary storing the serialized version of an object
of this class
""" """
self.__dict__.update(state) self.__dict__.update(state)
@ -135,7 +147,8 @@ class Simulator:
def _update_statistics(self, bit_errors: int) -> None: def _update_statistics(self, bit_errors: int) -> None:
"""Update the statistics of the simulator. """Update the statistics of the simulator.
:param bit_errors: Number of bit errors that occurred during the last transmission :param bit_errors: Number of bit errors that occurred during the
last transmission
""" """
self._curr_num_iterations += 1 self._curr_num_iterations += 1
@ -153,7 +166,8 @@ class Simulator:
""" """
if self._curr_num_frame_errors >= self._target_frame_errors: if self._curr_num_frame_errors >= self._target_frame_errors:
self._BERs[self._current_decoder_index] \ self._BERs[self._current_decoder_index] \
.append(self._curr_num_bit_errors / (self._curr_num_iterations * self._n)) .append(self._curr_num_bit_errors / (
self._curr_num_iterations * self._n))
self._curr_num_frame_errors = 0 self._curr_num_frame_errors = 0
self._curr_num_bit_errors = 0 self._curr_num_bit_errors = 0
@ -163,7 +177,9 @@ class Simulator:
self._current_SNRs_index += 1 self._current_SNRs_index += 1
self._snr_pbar.reset() self._snr_pbar.reset()
self._snr_pbar.set_description(f"Simulating for SNR = {self._SNRs[self._current_SNRs_index]} dB") self._snr_pbar.set_description(
f"Simulating for SNR = "
f"{self._SNRs[self._current_SNRs_index]} dB")
self._decoder_pbar.update(1) self._decoder_pbar.update(1)
else: else:
if self._current_decoder_index < len(self._decoders) - 1: if self._current_decoder_index < len(self._decoders) - 1:
@ -173,7 +189,8 @@ class Simulator:
self._decoder_pbar.reset() self._decoder_pbar.reset()
decoder = self._decoders[self._current_decoder_index] decoder = self._decoders[self._current_decoder_index]
self._decoder_pbar.set_description(f"Calculating BERs for {decoder.__class__.__name__}") self._decoder_pbar.set_description(
f"Calculating BERs for {decoder.__class__.__name__}")
self._overall_pbar.update(1) self._overall_pbar.update(1)
else: else:
self._sim_running = False self._sim_running = False
@ -202,19 +219,23 @@ class Simulator:
def get_current_results(self) -> pd.DataFrame: def get_current_results(self) -> pd.DataFrame:
"""Get the current results. """Get the current results.
If the simulation has not yet completed, the BERs which have not yet been calculated are set to 0. If the simulation has not yet completed, the BERs which have not yet
been calculated are set to 0.
:return: pandas Dataframe with the columns ["SNR", "BER_1", "BER_2", ...] :return: pandas Dataframe with the columns ["SNR", "BER_1", "BER_2",
...]
""" """
data = {"SNR": np.array(self._SNRs)} data = {"SNR": np.array(self._SNRs)}
# If the BERs of a decoder have not been calculated for all SNRs, # If the BERs of a decoder have not been calculated for all SNRs,
# fill the rest up with zeros to match the length of the 'SNRs' array # fill the rest up with zeros to match the length of the 'SNRs' array
for i, decoder_BER_list in enumerate(self._BERs): for i, decoder_BER_list in enumerate(self._BERs):
padded = np.pad(decoder_BER_list, (0, len(self._SNRs) - len(decoder_BER_list))) padded = np.pad(decoder_BER_list,
(0, len(self._SNRs) - len(decoder_BER_list)))
data[f"BER_{i}"] = padded data[f"BER_{i}"] = padded
# If the BERs have not been calculated for all decoders, fill up the BERs list # If the BERs have not been calculated for all decoders, fill up the
# BERs list
# with zero-vectors to match the length of the 'decoders' list # with zero-vectors to match the length of the 'decoders' list
for i in range(len(self._decoders), len(self._BERs)): for i in range(len(self._decoders), len(self._BERs)):
data[f"BER_{i}"] = np.zeros(len(self._SNRs)) data[f"BER_{i}"] = np.zeros(len(self._SNRs))
@ -224,7 +245,9 @@ class Simulator:
# TODO: Fix typing.Any or Simulator # TODO: Fix typing.Any or Simulator
class SimulationDeSerializer: class SimulationDeSerializer:
"""Class responsible for file management, de- and serialization of Simulator objects.""" """Class responsible for file management, de- and serialization of
Simulator objects."""
def __init__(self, save_dir: str, results_dir: str): def __init__(self, save_dir: str, results_dir: str):
self._save_dir = save_dir self._save_dir = save_dir
self._results_dir = results_dir self._results_dir = results_dir
@ -260,7 +283,8 @@ class SimulationDeSerializer:
os.remove(self._get_savefile_path(sim_name)) os.remove(self._get_savefile_path(sim_name))
# os.remove(self._get_metadata_path(sim_name)) # os.remove(self._get_metadata_path(sim_name))
def save_state(self, simulator: typing.Any, sim_name: str, metadata: typing.Dict) -> None: def save_state(self, simulator: typing.Any, sim_name: str,
metadata: typing.Dict) -> None:
"""Save the state of a currently running simulation. """Save the state of a currently running simulation.
:param simulator: Simulator object :param simulator: Simulator object
@ -268,14 +292,16 @@ class SimulationDeSerializer:
:param metadata: Metadata to be saved besides the actual state :param metadata: Metadata to be saved besides the actual state
""" """
# Save metadata # Save metadata
with open(self._get_metadata_path(sim_name), 'w+', encoding='utf-8') as f: with open(self._get_metadata_path(sim_name), 'w+',
encoding='utf-8') as f:
json.dump(metadata, f, ensure_ascii=False, indent=4) json.dump(metadata, f, ensure_ascii=False, indent=4)
# Save simulation state # Save simulation state
with open(self._get_savefile_path(sim_name), "wb") as file: with open(self._get_savefile_path(sim_name), "wb") as file:
pickle.dump(simulator, file) pickle.dump(simulator, file)
def read_state(self, sim_name: str) -> typing.Tuple[typing.Any, typing.Dict]: def read_state(self, sim_name: str) -> typing.Tuple[
typing.Any, typing.Dict]:
"""Read the saved state of a paused simulation. """Read the saved state of a paused simulation.
:param sim_name: Name of the simulation :param sim_name: Name of the simulation
@ -285,7 +311,8 @@ class SimulationDeSerializer:
simulator = None simulator = None
# Read metadata # Read metadata
with open(self._get_metadata_path(sim_name), 'r', encoding='utf-8') as f: with open(self._get_metadata_path(sim_name), 'r',
encoding='utf-8') as f:
metadata = json.load(f) metadata = json.load(f)
# Read simulation state # Read simulation state
@ -295,29 +322,35 @@ class SimulationDeSerializer:
return simulator, metadata return simulator, metadata
# TODO: Is the simulator object actually necessary here? # TODO: Is the simulator object actually necessary here?
def save_results(self, simulator: typing.Any, sim_name: str, metadata: typing.Dict) -> None: def save_results(self, simulator: typing.Any, sim_name: str,
metadata: typing.Dict) -> None:
"""Save simulation results to file. """Save simulation results to file.
:param simulator: Simulator object. Used to obtain the data :param simulator: Simulator object. Used to obtain the data
:param sim_name: Name of the simulation. Determines the filename :param sim_name: Name of the simulation. Determines the filename
:param metadata: Metadata to be saved besides the actual simulation results :param metadata: Metadata to be saved besides the actual simulation
results
""" """
# Save metadata # Save metadata
with open(self._get_metadata_path(sim_name), 'w+', encoding='utf-8') as f: with open(self._get_metadata_path(sim_name), 'w+',
encoding='utf-8') as f:
json.dump(metadata, f, ensure_ascii=False, indent=4) json.dump(metadata, f, ensure_ascii=False, indent=4)
# Save results # Save results
df = simulator.get_current_results() df = simulator.get_current_results()
df.to_csv(self._get_results_path(sim_name)) df.to_csv(self._get_results_path(sim_name))
def read_results(self, sim_name: str) -> typing.Tuple[pd.DataFrame, typing.Dict]: def read_results(self, sim_name: str) -> typing.Tuple[
pd.DataFrame, typing.Dict]:
"""Read simulation results from file. """Read simulation results from file.
:param sim_name: Name of the simulation. :param sim_name: Name of the simulation.
:return: Tuple of the form (data, metadata), where data is a pandas dataframe and metadata is a dict :return: Tuple of the form (data, metadata), where data is a pandas
dataframe and metadata is a dict
""" """
# Read metadata # Read metadata
with open(self._get_metadata_path(sim_name), 'r', encoding='utf-8') as f: with open(self._get_metadata_path(sim_name), 'r',
encoding='utf-8') as f:
metadata = json.load(f) metadata = json.load(f)
# Read results # Read results
@ -329,8 +362,9 @@ class SimulationDeSerializer:
# TODO: Fix typing.Any or Simulator # TODO: Fix typing.Any or Simulator
# TODO: Autosave simulation every so often # TODO: Autosave simulation every so often
class SimulationManager: class SimulationManager:
"""This class only contains functions relating to stopping and restarting of simulations """This class only contains functions relating to stopping and
(and storing of the simulation state in a file, to be resumed at a later date). restarting of simulations (and storing of the simulation state in a
file, to be resumed at a later date).
All actual work is outsourced to a provided simulator class. All actual work is outsourced to a provided simulator class.
""" """
@ -338,8 +372,10 @@ class SimulationManager:
def __init__(self, save_dir: str, results_dir: str): def __init__(self, save_dir: str, results_dir: str):
"""Construct a SimulationManager object. """Construct a SimulationManager object.
:param save_dir: Directory in which the simulation state of a paused simulation should be stored :param save_dir: Directory in which the simulation state of a paused
:param results_dir: Directory in which the results of the simulation should be stored simulation should be stored
:param results_dir: Directory in which the results of the simulation
should be stored
""" """
self._de_serializer = SimulationDeSerializer(save_dir, results_dir) self._de_serializer = SimulationDeSerializer(save_dir, results_dir)
@ -357,14 +393,16 @@ class SimulationManager:
and (self._sim_name is not None) \ and (self._sim_name is not None) \
and (self._metadata is not None) and (self._metadata is not None)
def configure_simulation(self, simulator: typing.Any, name: str, column_labels: typing.Sequence[str]) -> None: def configure_simulation(self, simulator: typing.Any, name: str,
column_labels: typing.Sequence[str]) -> None:
"""Configure a new simulation.""" """Configure a new simulation."""
self._simulator = simulator self._simulator = simulator
self._sim_name = name self._sim_name = name
self._metadata["labels"] = column_labels self._metadata["labels"] = column_labels
def unfinished_simulation_present(self, sim_name: str) -> bool: def unfinished_simulation_present(self, sim_name: str) -> bool:
"""Check whether the savefile of a previously unfinished simulation is present.""" """Check whether the savefile of a previously unfinished simulation
is present."""
return self._de_serializer.unfinished_sim_present(sim_name) return self._de_serializer.unfinished_sim_present(sim_name)
def load_unfinished(self, sim_name: str) -> None: def load_unfinished(self, sim_name: str) -> None:
@ -375,16 +413,20 @@ class SimulationManager:
assert self.unfinished_simulation_present(sim_name) assert self.unfinished_simulation_present(sim_name)
self._sim_name = sim_name self._sim_name = sim_name
self._simulator, self._metadata = self._de_serializer.read_state(sim_name) self._simulator, self._metadata = self._de_serializer.read_state(
sim_name)
self._de_serializer.remove_unfinished_sim(sim_name) self._de_serializer.remove_unfinished_sim(sim_name)
def _exit_gracefully(self, *args) -> None: def _exit_gracefully(self, *args) -> None:
"""Handler called when the program is interrupted. Pauses and saves the currently running simulation.""" """Handler called when the program is interrupted. Pauses and saves
the currently running simulation."""
if self._sim_configured(): if self._sim_configured():
self._simulator.stop() self._simulator.stop()
self._de_serializer.save_state(self._simulator, self._sim_name, self._metadata) self._de_serializer.save_state(self._simulator, self._sim_name,
self._de_serializer.save_results(self._simulator, self._sim_name, self._metadata) self._metadata)
self._de_serializer.save_results(self._simulator, self._sim_name,
self._metadata)
exit() exit()
@ -393,4 +435,5 @@ class SimulationManager:
assert self._sim_configured() assert self._sim_configured()
self._simulator.start() self._simulator.start()
self._de_serializer.save_results(self._simulator, self._sim_name, self._metadata) self._de_serializer.save_results(self._simulator, self._sim_name,
self._metadata)

View File

@ -7,8 +7,8 @@ import math
def _get_num_rows(num_graphs: int, num_cols: int) -> int: def _get_num_rows(num_graphs: int, num_cols: int) -> int:
"""Get the minimum number of rows needed to show a certain number of graphs, """Get the minimum number of rows needed to show a certain number of
given a certain number of columns. graphs, given a certain number of columns.
:param num_graphs: Number of graphs :param num_graphs: Number of graphs
:param num_cols: Number of columns :param num_cols: Number of columns
@ -19,17 +19,21 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:
# TODO: Handle number of graphs not nicely fitting into rows and columns # TODO: Handle number of graphs not nicely fitting into rows and columns
def plot_BERs(title: str, def plot_BERs(title: str,
data: typing.Sequence[typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]], data: typing.Sequence[
typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]],
num_cols: int = 3) -> plt.figure: num_cols: int = 3) -> plt.figure:
"""This function creates a matplotlib figure containing a number of plots. """This function creates a matplotlib figure containing a number of plots.
The plots created are logarithmic and the scaling is adjusted to be sensible for BER plots. The plots created are logarithmic and the scaling is adjusted to be
sensible for BER plots.
:param title: Title of the figure :param title: Title of the figure
:param data: Sequence of tuples. Each tuple corresponds to a new plot and :param data: Sequence of tuples. Each tuple corresponds to a new plot and
is of the following form: [graph_title, pd.Dataframe, [line_label_1, line_label2, ...]]. is of the following form: [graph_title, pd.Dataframe, [line_label_1,
Each dataframe is assumed to have an "SNR" column that is used as the x axis. line_label2, ...]]. Each dataframe is assumed to have an "SNR" column
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure that is used as the x axis.
:param num_cols: Number of columns in which the graphs should be
arranged in the resulting figure
:return: Matplotlib figure :return: Matplotlib figure
""" """
# Determine layout and create figure # Determine layout and create figure
@ -37,7 +41,9 @@ def plot_BERs(title: str,
num_graphs = len(data) num_graphs = len(data)
num_rows = _get_num_rows(num_graphs, num_cols) num_rows = _get_num_rows(num_graphs, num_cols)
fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols*4, num_rows*4), squeeze=False) fig, axes = plt.subplots(num_rows, num_cols,
figsize=(num_cols * 4, num_rows * 4),
squeeze=False)
fig.suptitle(title) fig.suptitle(title)
fig.subplots_adjust(left=0.1, fig.subplots_adjust(left=0.1,
@ -47,12 +53,14 @@ def plot_BERs(title: str,
wspace=0.3, wspace=0.3,
hspace=0.4) hspace=0.4)
axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array axes = list(chain.from_iterable(axes))[
:num_graphs] # Flatten the 2d axes array
# Populate axes # Populate axes
for axis, (graph_title, df, labels) in zip(axes, data): for axis, (graph_title, df, labels) in zip(axes, data):
column_names = [column for column in df.columns.values.tolist() if not column == "SNR"] column_names = [column for column in df.columns.values.tolist() if
not column == "SNR"]
for column, label in zip(column_names, labels): for column, label in zip(column_names, labels):
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label) sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label)