Reformatted all code

parent 3c27a0cc18
commit 8913246600
@@ -1 +1,2 @@
"""This package contains a number of different decoder implementations for LDPC codes."""
"""This package contains a number of different decoder implementations for
LDPC codes."""

@@ -3,9 +3,9 @@ import itertools


class MLDecoder:
"""This class naively implements a soft decision decoder. The decoder calculates
the correlation between the received signal and each codeword and then chooses the
one with the largest correlation.
"""This class naively implements a soft decision decoder. The decoder
calculates the correlation between the received signal and each codeword
and then chooses the one with the largest correlation.
"""

def __init__(self, G: np.array, H: np.array):
@@ -17,13 +17,15 @@ class MLDecoder:
self._G = G
self._H = H
self._datawords, self._codewords = self._gen_codewords()
self._codewords_bpsk = 1 - 2 * self._codewords # The codewords, but mapped to [-1, 1]^n

# The codewords, but mapped to [-1, 1]^n
self._codewords_bpsk = 1 - 2 * self._codewords

def _gen_codewords(self) -> np.array:
"""Generate a list of all possible codewords.

:return: Numpy array of the form [[codeword_1], [codeword_2], ...]
(Each generated codeword is an element of [0, 1]^n)
(Each generated codeword is an element of [0, 1]^n)
"""
k, n = self._G.shape

@@ -41,8 +43,8 @@ class MLDecoder:

This function assumes a BPSK modulated signal.

:param y: Vector of received values. (y = x + w, where 'x' is element of [-1, 1]^n
and 'w' is noise)
:param y: Vector of received values. (y = x + w, where 'x' is
element of [-1, 1]^n and 'w' is noise)
:return: Most probably sent codeword (element of [0, 1]^k)
"""
correlations = np.dot(self._codewords_bpsk, y)

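As an aside: the MLDecoder hunks above enumerate every codeword, map it to BPSK with 1 - 2*c, and pick the codeword whose correlation with the received vector is largest. Below is a minimal self-contained sketch of that idea; the tiny (3, 2) generator matrix and the noise level are invented for illustration and are not taken from this repository.

```python
import itertools

import numpy as np

# Hypothetical toy generator matrix of a (3, 2) code, chosen only for illustration.
G = np.array([[1, 0, 1],
              [0, 1, 1]])

# All datawords and the codewords they generate (arithmetic mod 2).
datawords = np.array(list(itertools.product([0, 1], repeat=G.shape[0])))
codewords = datawords @ G % 2

# BPSK mapping used in the diff: bit 0 -> +1, bit 1 -> -1.
codewords_bpsk = 1 - 2 * codewords

# Received vector y = x + w for the third codeword plus some noise.
rng = np.random.default_rng(0)
y = codewords_bpsk[2] + 0.3 * rng.standard_normal(G.shape[1])

# Soft-decision ML decoding: the codeword with the largest correlation wins.
best = np.argmax(codewords_bpsk @ y)
print(datawords[best], codewords[best])
```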
@@ -2,7 +2,8 @@ import numpy as np


class ProximalDecoder:
"""Class implementing the Proximal Decoding algorithm. See "Proximal Decoding for LDPC Codes"
"""Class implementing the Proximal Decoding algorithm. See "Proximal
Decoding for LDPC Codes"
by Tadashi Wadayama, and Satoshi Takabe.
"""

@@ -13,7 +14,8 @@ class ProximalDecoder:
:param H: Parity Check Matrix
:param K: Max number of iterations to perform when decoding
:param omega: Step size for the gradient descent process
:param gamma: Positive constant. Arises in the approximation of the prior PDF
:param gamma: Positive constant. Arises in the approximation of the
prior PDF
:param eta: Positive constant slightly larger than one. See 3.2, p. 3
"""
self._H = H
@@ -27,8 +29,8 @@ class ProximalDecoder:

@staticmethod
def _L_awgn(s: np.array, y: np.array) -> np.array:
"""Variation of the negative log-likelihood for the special case of AWGN noise.
See 4.1, p. 4.
"""Variation of the negative log-likelihood for the special case of
AWGN noise. See 4.1, p. 4.
"""
return s - y

@@ -41,15 +43,15 @@ class ProximalDecoder:

# Calculate gradient

sums = np.dot(A_prods**2 - A_prods, self._H)
sums = np.dot(A_prods ** 2 - A_prods, self._H)

result = 4 * (x**2 - 1) * x + (2 / x) * sums
result = 4 * (x ** 2 - 1) * x + (2 / x) * sums

return result

def _projection(self, v):
"""Project a vector onto [-eta, eta]^n in order to avoid numerical instability.
Detailed in 3.2, p. 3 (Equation (15)).
"""Project a vector onto [-eta, eta]^n in order to avoid numerical
instability. Detailed in 3.2, p. 3 (Equation (15)).

:param v: Vector to project
:return: x clipped to [-eta, eta]^n
@@ -60,7 +62,8 @@ class ProximalDecoder:
"""Perform a parity check for a given codeword.

:param x_hat: codeword to be checked (element of [0, 1]^n)
:return: True if the parity check passes, i.e. the codeword is valid. False otherwise
:return: True if the parity check passes, i.e. the codeword is
valid. False otherwise
"""
syndrome = np.dot(self._H, x_hat) % 2
return not np.any(syndrome)
@@ -70,8 +73,8 @@ class ProximalDecoder:

This function assumes a BPSK modulated signal and an AWGN channel.

:param y: Vector of received values. (y = x + w, where 'x' is element of [-1, 1]^n
and 'w' is noise)
:param y: Vector of received values. (y = x + w, where 'x' is
element of [-1, 1]^n and 'w' is noise)
:return: Most probably sent codeword (element of [0, 1]^n)
"""
s = np.zeros(self._n)
@@ -83,7 +86,9 @@ class ProximalDecoder:
s = self._projection(s) # Equation (15)

x_hat = np.sign(s)
x_hat = (x_hat == -1) * 1 # Map the codeword from [-1, 1]^n to [0, 1]^n

# Map the codeword from [ -1, 1]^n to [0, 1]^n
x_hat = (x_hat == -1) * 1

if self._check_parity(x_hat):
break

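As an aside: the decode loop reformatted above repeats a gradient step, the projection of equation (15), a hard decision with np.sign, a mapping from [-1, 1]^n back to bits, and a parity check, stopping as soon as the check passes. The snippet below illustrates only the hard-decision and parity-check steps; the (7, 4) Hamming parity check matrix and the soft estimate are invented for the example, so this is a sketch rather than the repository's code.

```python
import numpy as np

# Hypothetical parity check matrix of a (7, 4) Hamming code.
H = np.array([[0, 0, 0, 1, 1, 1, 1],
              [0, 1, 1, 0, 0, 1, 1],
              [1, 0, 1, 0, 1, 0, 1]])

# Invented soft estimate after some gradient iterations (bipolar, noisy).
s = np.array([0.9, 1.1, -0.8, 1.2, -0.7, -1.05, 0.95])

# Hard decision and mapping from [-1, 1]^n back to [0, 1]^n, as in the diff.
x_hat = np.sign(s)
x_hat = (x_hat == -1) * 1

# Parity check: the syndrome H @ x_hat (mod 2) must be all zeros for a valid codeword.
syndrome = np.dot(H, x_hat) % 2
print(x_hat, not np.any(syndrome))  # prints the hard decision and True for this example
```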
sw/main.py (15 changed lines)
@@ -1,7 +1,5 @@
import numpy as np
import pandas as pd
from pathlib import Path
import sys

from decoders import proximal, maximum_likelihood
from utility import simulation, codes
@@ -12,7 +10,8 @@ def main():

sim_name = "test"

sim_mgr = simulation.SimulationManager(results_dir="sim_results", save_dir="sim_saves")
sim_mgr = simulation.SimulationManager(results_dir="sim_results",
save_dir="sim_saves")

if sim_mgr.unfinished_simulation_present(sim_name):
print("Found unfinished simulation. Picking up where it was left of")
@@ -28,8 +27,7 @@ def main():
# H = codes.read_alist_file("res/999.111.3.5543.alist")
# H = codes.read_alist_file("res/999.111.3.5565.alist")
H = codes.read_alist_file("res/816.1A4.845.alist")
k = 272
n = 816
k, n = H.shape

decoders = [
proximal.ProximalDecoder(H, gamma=0.01),
@@ -43,9 +41,12 @@ def main():
"proximal $\\gamma = 0.15$"
]

sim = simulation.Simulator(n=n, k=k, decoders=decoders, target_frame_errors=3, SNRs=np.arange(1, 6, 0.5))
sim = simulation.Simulator(n=n, k=k, decoders=decoders,
target_frame_errors=3,
SNRs=np.arange(1, 6, 0.5))

sim_mgr.configure_simulation(simulator=sim, name=sim_name, column_labels=labels)
sim_mgr.configure_simulation(simulator=sim, name=sim_name,
column_labels=labels)
sim_mgr.simulate()

@@ -7,7 +7,6 @@ from utility import visualization, simulation


def plot_results():

sim_names = [
"96.3965",
"204.3.486",
@@ -19,7 +18,8 @@ def plot_results():
"PEGReg252x504"
]

deserializer = simulation.SimulationDeSerializer(save_dir="sim_saves", results_dir="sim_results")
deserializer = simulation.SimulationDeSerializer(save_dir="sim_saves",
results_dir="sim_results")

data = []
for sim_name in sim_names:
@@ -33,8 +33,9 @@ def plot_results():
data.append(graph_tuple)

sns.set_theme()
fig = visualization.plot_BERs(title="Bit-Error-Rates of proximal decoder for different codes",
data=data, num_cols=4)
fig = visualization.plot_BERs(
title="Bit-Error-Rates of proximal decoder for different codes",
data=data, num_cols=4)

plt.show()

@@ -1 +1,2 @@
"""This package contains various utilities that can be used in combination with the decoders."""
"""This package contains various utilities that can be used in combination
with the decoders."""

@@ -1,8 +1,8 @@
"""This file Helper functions for generating an H matrix from alist data.
Code from https://github.com/gnuradio/gnuradio/blob/master/gr-fec/python/fec/LDPC/Generate_LDPC_matrix_functions.py
Code from https://github.com/gnuradio/gnuradio/blob/master/gr-fec/python/fec
/LDPC/Generate_LDPC_matrix_functions.py
"""


import numpy as np


@@ -44,6 +44,8 @@ def read_alist_file(filename):
#


# @formatter:off

Gs = {'Hamming_7_4': np.array([[1, 0, 0, 0, 0, 1, 1],
[0, 1, 0, 0, 1, 0, 1],
[0, 0, 1, 0, 1, 1, 0],
@@ -219,6 +221,8 @@ Gs = {'Hamming_7_4': np.array([[1, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1]])
}

# @formatter:on


#
# Utilities for systematic codes

@@ -4,7 +4,8 @@ import re

def slugify(value, allow_unicode=False):
"""
Taken from https://github.com/django/django/blob/master/django/utils/text.py
Taken from https://github.com/django/django/blob/master/django/utils
/text.py
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Convert to lowercase. Also strip leading and
@@ -14,6 +15,8 @@ def slugify(value, allow_unicode=False):
if allow_unicode:
value = unicodedata.normalize('NFKC', value)
else:
value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
value = unicodedata.normalize('NFKD', value).encode('ascii',
'ignore').decode(
'ascii')
value = re.sub(r'[^\w\s-]', '', value.lower())
return re.sub(r'[-\s]+', '-', value).strip('-_')

@@ -1,11 +1,11 @@
"""Utility functions relating to noise and SNR calculations."""


import numpy as np


def get_noise_variance_from_SNR(SNR: float, n: int, k: int) -> float:
"""Calculate the variance of the noise from an SNR and the signal amplitude.
"""Calculate the variance of the noise from an SNR and the signal
amplitude.

:param SNR: Signal-to-Noise-Ratio in dB (E_b/N_0)
:param n: Length of a codeword of the used code
@@ -13,14 +13,15 @@ def get_noise_variance_from_SNR(SNR: float, n: int, k: int) -> float:
:return: Variance of the noise
"""
SNR_linear = 10 ** (SNR / 10)
variance = 1 / (2 * (k/n) * SNR_linear)
variance = 1 / (2 * (k / n) * SNR_linear)

return variance


def add_awgn(c: np.array, SNR: float, n: int, k: int) -> np.array:
"""Add Additive White Gaussian Noise to a data vector. As this function adds random noise to
the input, the output changes, even if it is called multiple times with the same input.
"""Add Additive White Gaussian Noise to a data vector. As this function
adds random noise to the input, the output changes, even if it is called
multiple times with the same input.

:param c: Binary vector representing the data to be transmitted
:param SNR: Signal-to-Noise-Ratio in dB

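As an aside: the noise utility above converts E_b/N_0 in dB to a linear ratio and then to a per-sample noise variance via variance = 1 / (2 * (k / n) * SNR_linear). Below is a short worked example of that formula; the code rate used for the numbers and the way the noise is drawn with np.random.normal are assumptions for illustration, not a quote of add_awgn.

```python
import numpy as np


def noise_variance_from_snr(snr_db: float, n: int, k: int) -> float:
    """Noise variance for a given E_b/N_0 in dB and code rate k/n (formula from the diff)."""
    snr_linear = 10 ** (snr_db / 10)
    return 1 / (2 * (k / n) * snr_linear)


# Worked numbers: 3 dB at rate 1/2 -> 10**0.3 is roughly 2, so the variance is roughly 0.5.
print(noise_variance_from_snr(3.0, n=816, k=408))

# One plausible way to add AWGN to a BPSK-mapped word (illustrative assumption).
c = np.array([0, 1, 1, 0])
x = 1 - 2 * c  # BPSK mapping as used elsewhere in the diff
sigma2 = noise_variance_from_snr(3.0, n=816, k=408)
y = x + np.random.normal(0.0, np.sqrt(sigma2), size=x.shape)
print(y)
```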
@@ -1,4 +1,5 @@
"""This file contains utility functions relating to tests and simulations of the decoders."""
"""This file contains utility functions relating to tests and simulations of
the decoders."""
import json
import pandas as pd
import numpy as np
@@ -23,15 +24,18 @@ def count_bit_errors(d: np.array, d_hat: np.array) -> int:


# TODO: Write unit tests
# TODO: Create generic Simulator Interface which should be implemented for
# specific applications
class Simulator:
"""Class allowing for saving of simulations state.

Given a list of decoders, this class allows for simulating the Bit-Error-Rates of each decoder
for various SNRs.
Given a list of decoders, this class allows for simulating the
Bit-Error-Rates of each decoder for various SNRs.

The functionality implemented by this class could be achieved by a bunch of loops and a function.
However, storing the state of the simulation as member variables allows for pausing and resuming
the simulation at a later time.
The functionality implemented by this class could be achieved by a bunch
of loops and a function. However, storing the state of the simulation as
member variables allows for pausing and resuming the simulation at a
later time.
"""

def __init__(self, n: int, k: int,
@@ -44,7 +48,8 @@ class Simulator:
:param k: Number of bits in a dataword
:param decoders: Sequence of decoders to test
:param SNRs: Sequence of SNRs for which the BERs should be calculated
:param target_frame_errors: Number of frame errors after which to stop the simulation
:param target_frame_errors: Number of frame errors after which to
stop the simulation
"""
# Simulation parameters

@@ -76,20 +81,26 @@ class Simulator:

def _create_pbars(self):
self._overall_pbar = tqdm(total=len(self._decoders),
desc="Calculating the answer to life, the universe and everything",
desc="Calculating the answer to life, "
"the universe and everything",
leave=False,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}")
bar_format="{l_bar}{bar}| {n_fmt}/{"
"total_fmt}")

decoder = self._decoders[self._current_decoder_index]
self._decoder_pbar = tqdm(total=len(self._SNRs),
desc=f"Calculating BERs for {decoder.__class__.__name__}",
desc=f"Calculatin"
f"g BERs"
f" for {decoder.__class__.__name__}",
leave=False,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}")
bar_format="{l_bar}{bar}| {n_fmt}/{"
"total_fmt}")

self._snr_pbar = tqdm(total=self._target_frame_errors,
desc=f"Simulating for SNR = {self._SNRs[0]} dB",
leave=False,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]")
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} "
"[{elapsed}<{remaining}]")

def __getstate__(self) -> typing.Dict:
"""Custom serialization function called by the 'pickle' module
@@ -105,7 +116,8 @@ class Simulator:
"""Custom deserialization function called by the 'pickle' module
when loading a previously saved simulation

:param state: Dictionary storing the serialized version of an object of this class
:param state: Dictionary storing the serialized version of an object
of this class
"""
self.__dict__.update(state)

@@ -135,7 +147,8 @@ class Simulator:
def _update_statistics(self, bit_errors: int) -> None:
"""Update the statistics of the simulator.

:param bit_errors: Number of bit errors that occurred during the last transmission
:param bit_errors: Number of bit errors that occurred during the
last transmission
"""
self._curr_num_iterations += 1

@@ -153,7 +166,8 @@ class Simulator:
"""
if self._curr_num_frame_errors >= self._target_frame_errors:
self._BERs[self._current_decoder_index] \
.append(self._curr_num_bit_errors / (self._curr_num_iterations * self._n))
.append(self._curr_num_bit_errors / (
self._curr_num_iterations * self._n))

self._curr_num_frame_errors = 0
self._curr_num_bit_errors = 0
@@ -163,7 +177,9 @@ class Simulator:
self._current_SNRs_index += 1

self._snr_pbar.reset()
self._snr_pbar.set_description(f"Simulating for SNR = {self._SNRs[self._current_SNRs_index]} dB")
self._snr_pbar.set_description(
f"Simulating for SNR = "
f"{self._SNRs[self._current_SNRs_index]} dB")
self._decoder_pbar.update(1)
else:
if self._current_decoder_index < len(self._decoders) - 1:
@@ -173,7 +189,8 @@ class Simulator:

self._decoder_pbar.reset()
decoder = self._decoders[self._current_decoder_index]
self._decoder_pbar.set_description(f"Calculating BERs for {decoder.__class__.__name__}")
self._decoder_pbar.set_description(
f"Calculating BERs for {decoder.__class__.__name__}")
self._overall_pbar.update(1)
else:
self._sim_running = False
@@ -202,19 +219,23 @@ class Simulator:
def get_current_results(self) -> pd.DataFrame:
"""Get the current results.

If the simulation has not yet completed, the BERs which have not yet been calculated are set to 0.
If the simulation has not yet completed, the BERs which have not yet
been calculated are set to 0.

:return: pandas Dataframe with the columns ["SNR", "BER_1", "BER_2", ...]
:return: pandas Dataframe with the columns ["SNR", "BER_1", "BER_2",
...]
"""
data = {"SNR": np.array(self._SNRs)}

# If the BERs of a decoder have not been calculated for all SNRs,
# fill the rest up with zeros to match the length of the 'SNRs' array
for i, decoder_BER_list in enumerate(self._BERs):
padded = np.pad(decoder_BER_list, (0, len(self._SNRs) - len(decoder_BER_list)))
padded = np.pad(decoder_BER_list,
(0, len(self._SNRs) - len(decoder_BER_list)))
data[f"BER_{i}"] = padded

# If the BERs have not been calculated for all decoders, fill up the BERs list
# If the BERs have not been calculated for all decoders, fill up the
# BERs list
# with zero-vectors to match the length of the 'decoders' list
for i in range(len(self._decoders), len(self._BERs)):
data[f"BER_{i}"] = np.zeros(len(self._SNRs))
@@ -224,7 +245,9 @@ class Simulator:

# TODO: Fix typing.Any or Simulator
class SimulationDeSerializer:
"""Class responsible for file management, de- and serialization of Simulator objects."""
"""Class responsible for file management, de- and serialization of
Simulator objects."""

def __init__(self, save_dir: str, results_dir: str):
self._save_dir = save_dir
self._results_dir = results_dir
@@ -260,7 +283,8 @@ class SimulationDeSerializer:
os.remove(self._get_savefile_path(sim_name))
# os.remove(self._get_metadata_path(sim_name))

def save_state(self, simulator: typing.Any, sim_name: str, metadata: typing.Dict) -> None:
def save_state(self, simulator: typing.Any, sim_name: str,
metadata: typing.Dict) -> None:
"""Save the state of a currently running simulation.

:param simulator: Simulator object
@@ -268,14 +292,16 @@ class SimulationDeSerializer:
:param metadata: Metadata to be saved besides the actual state
"""
# Save metadata
with open(self._get_metadata_path(sim_name), 'w+', encoding='utf-8') as f:
with open(self._get_metadata_path(sim_name), 'w+',
encoding='utf-8') as f:
json.dump(metadata, f, ensure_ascii=False, indent=4)

# Save simulation state
with open(self._get_savefile_path(sim_name), "wb") as file:
pickle.dump(simulator, file)

def read_state(self, sim_name: str) -> typing.Tuple[typing.Any, typing.Dict]:
def read_state(self, sim_name: str) -> typing.Tuple[
typing.Any, typing.Dict]:
"""Read the saved state of a paused simulation.

:param sim_name: Name of the simulation
@@ -285,7 +311,8 @@ class SimulationDeSerializer:
simulator = None

# Read metadata
with open(self._get_metadata_path(sim_name), 'r', encoding='utf-8') as f:
with open(self._get_metadata_path(sim_name), 'r',
encoding='utf-8') as f:
metadata = json.load(f)

# Read simulation state
@@ -295,29 +322,35 @@ class SimulationDeSerializer:
return simulator, metadata

# TODO: Is the simulator object actually necessary here?
def save_results(self, simulator: typing.Any, sim_name: str, metadata: typing.Dict) -> None:
def save_results(self, simulator: typing.Any, sim_name: str,
metadata: typing.Dict) -> None:
"""Save simulation results to file.

:param simulator: Simulator object. Used to obtain the data
:param sim_name: Name of the simulation. Determines the filename
:param metadata: Metadata to be saved besides the actual simulation results
:param metadata: Metadata to be saved besides the actual simulation
results
"""
# Save metadata
with open(self._get_metadata_path(sim_name), 'w+', encoding='utf-8') as f:
with open(self._get_metadata_path(sim_name), 'w+',
encoding='utf-8') as f:
json.dump(metadata, f, ensure_ascii=False, indent=4)

# Save results
df = simulator.get_current_results()
df.to_csv(self._get_results_path(sim_name))

def read_results(self, sim_name: str) -> typing.Tuple[pd.DataFrame, typing.Dict]:
def read_results(self, sim_name: str) -> typing.Tuple[
pd.DataFrame, typing.Dict]:
"""Read simulation results from file.

:param sim_name: Name of the simulation.
:return: Tuple of the form (data, metadata), where data is a pandas dataframe and metadata is a dict
:return: Tuple of the form (data, metadata), where data is a pandas
dataframe and metadata is a dict
"""
# Read metadata
with open(self._get_metadata_path(sim_name), 'r', encoding='utf-8') as f:
with open(self._get_metadata_path(sim_name), 'r',
encoding='utf-8') as f:
metadata = json.load(f)

# Read results
@@ -329,8 +362,9 @@ class SimulationDeSerializer:
# TODO: Fix typing.Any or Simulator
# TODO: Autosave simulation every so often
class SimulationManager:
"""This class only contains functions relating to stopping and restarting of simulations
(and storing of the simulation state in a file, to be resumed at a later date).
"""This class only contains functions relating to stopping and
restarting of simulations (and storing of the simulation state in a
file, to be resumed at a later date).

All actual work is outsourced to a provided simulator class.
"""
@@ -338,8 +372,10 @@ class SimulationManager:

def __init__(self, save_dir: str, results_dir: str):
"""Construct a SimulationManager object.

:param save_dir: Directory in which the simulation state of a paused simulation should be stored
:param results_dir: Directory in which the results of the simulation should be stored
:param save_dir: Directory in which the simulation state of a paused
simulation should be stored
:param results_dir: Directory in which the results of the simulation
should be stored
"""
self._de_serializer = SimulationDeSerializer(save_dir, results_dir)

@@ -357,14 +393,16 @@ class SimulationManager:
and (self._sim_name is not None) \
and (self._metadata is not None)

def configure_simulation(self, simulator: typing.Any, name: str, column_labels: typing.Sequence[str]) -> None:
def configure_simulation(self, simulator: typing.Any, name: str,
column_labels: typing.Sequence[str]) -> None:
"""Configure a new simulation."""
self._simulator = simulator
self._sim_name = name
self._metadata["labels"] = column_labels

def unfinished_simulation_present(self, sim_name: str) -> bool:
"""Check whether the savefile of a previously unfinished simulation is present."""
"""Check whether the savefile of a previously unfinished simulation
is present."""
return self._de_serializer.unfinished_sim_present(sim_name)

def load_unfinished(self, sim_name: str) -> None:
@@ -375,16 +413,20 @@ class SimulationManager:
assert self.unfinished_simulation_present(sim_name)

self._sim_name = sim_name
self._simulator, self._metadata = self._de_serializer.read_state(sim_name)
self._simulator, self._metadata = self._de_serializer.read_state(
sim_name)

self._de_serializer.remove_unfinished_sim(sim_name)

def _exit_gracefully(self, *args) -> None:
"""Handler called when the program is interrupted. Pauses and saves the currently running simulation."""
"""Handler called when the program is interrupted. Pauses and saves
the currently running simulation."""
if self._sim_configured():
self._simulator.stop()
self._de_serializer.save_state(self._simulator, self._sim_name, self._metadata)
self._de_serializer.save_results(self._simulator, self._sim_name, self._metadata)
self._de_serializer.save_state(self._simulator, self._sim_name,
self._metadata)
self._de_serializer.save_results(self._simulator, self._sim_name,
self._metadata)

exit()

@@ -393,4 +435,5 @@ class SimulationManager:
assert self._sim_configured()

self._simulator.start()
self._de_serializer.save_results(self._simulator, self._sim_name, self._metadata)
self._de_serializer.save_results(self._simulator, self._sim_name,
self._metadata)

@@ -7,8 +7,8 @@ import math


def _get_num_rows(num_graphs: int, num_cols: int) -> int:
"""Get the minimum number of rows needed to show a certain number of graphs,
given a certain number of columns.
"""Get the minimum number of rows needed to show a certain number of
graphs, given a certain number of columns.

:param num_graphs: Number of graphs
:param num_cols: Number of columns
@@ -19,17 +19,21 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:

# TODO: Handle number of graphs not nicely fitting into rows and columns
def plot_BERs(title: str,
data: typing.Sequence[typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]],
data: typing.Sequence[
typing.Tuple[str, pd.DataFrame, typing.Sequence[str]]],
num_cols: int = 3) -> plt.figure:
"""This function creates a matplotlib figure containing a number of plots.

The plots created are logarithmic and the scaling is adjusted to be sensible for BER plots.
The plots created are logarithmic and the scaling is adjusted to be
sensible for BER plots.

:param title: Title of the figure
:param data: Sequence of tuples. Each tuple corresponds to a new plot and
is of the following form: [graph_title, pd.Dataframe, [line_label_1, line_label2, ...]].
Each dataframe is assumed to have an "SNR" column that is used as the x axis.
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
is of the following form: [graph_title, pd.Dataframe, [line_label_1,
line_label2, ...]]. Each dataframe is assumed to have an "SNR" column
that is used as the x axis.
:param num_cols: Number of columns in which the graphs should be
arranged in the resulting figure
:return: Matplotlib figure
"""
# Determine layout and create figure
@@ -37,7 +41,9 @@ def plot_BERs(title: str,
num_graphs = len(data)
num_rows = _get_num_rows(num_graphs, num_cols)

fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols*4, num_rows*4), squeeze=False)
fig, axes = plt.subplots(num_rows, num_cols,
figsize=(num_cols * 4, num_rows * 4),
squeeze=False)
fig.suptitle(title)

fig.subplots_adjust(left=0.1,
@@ -47,12 +53,14 @@ def plot_BERs(title: str,
wspace=0.3,
hspace=0.4)

axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array
axes = list(chain.from_iterable(axes))[
:num_graphs] # Flatten the 2d axes array

# Populate axes

for axis, (graph_title, df, labels) in zip(axes, data):
column_names = [column for column in df.columns.values.tolist() if not column == "SNR"]
column_names = [column for column in df.columns.values.tolist() if
not column == "SNR"]

for column, label in zip(column_names, labels):
sns.lineplot(ax=axis, data=df, x="SNR", y=column, label=label)