diff --git a/sw/cpp/src/proximal.h b/sw/cpp/src/proximal.h index 65d6994..7676905 100644 --- a/sw/cpp/src/proximal.h +++ b/sw/cpp/src/proximal.h @@ -67,10 +67,12 @@ public: * This function assumes a BPSK modulated signal and an AWGN channel. * @param y Vector of received values. (y = x + w, where 'x' is element of * [-1, 1]^n and 'w' is noise) - * @return Most probably sent codeword (element of [0, 1]^n). If decoding - * fails, the returned value is 'None' + * @return \b std::pair of the form (x_hat, num_iter), x_hat is the most + * probably sent codeword and num_iter is the number of iterations that were + * performed. If the parity check fails and no valid codeword is reached, + * num_iter is -1 */ - std::pair>, int> + std::pair, int> decode(const Eigen::Ref>& y) { if (y.size() != mH.cols()) throw std::runtime_error("Length of vector must match H matrix"); @@ -95,12 +97,19 @@ public: } } - return {std::nullopt, mK}; + return {x_hat, -1}; } - /// Private members are not private in order to make the class easily - /// picklable - // private: + /** + * @brief Get the values of all member variables necessary to recreate an + * exact copy of this class. Used for pickling + * @return \b std::tuple + */ + auto get_decoder_state() const { + return std::tuple(mK, mOmega, mGamma, mEta, mH); + } + +private: const int mK; const double mOmega; const double mGamma; diff --git a/sw/cpp/src/python_interface.cpp b/sw/cpp/src/python_interface.cpp index f374bcf..d46040b 100644 --- a/sw/cpp/src/python_interface.cpp +++ b/sw/cpp/src/python_interface.cpp @@ -17,7 +17,13 @@ using namespace pybind11::literals; .def("decode", &ProximalDecoder::decode, "x"_a.noconvert()) \ .def(py::pickle( \ [](const ProximalDecoder& a) { \ - return py::make_tuple(a.mH, a.mK, a.mOmega, a.mGamma, a.mEta); \ + MatrixiR H; \ + int K; \ + double omega; \ + double gamma; \ + double eta; \ + std::tie(H, K, omega, gamma, eta) = a.get_decoder_state(); \ + return py::make_tuple(H, K, omega, gamma, eta); \ }, \ [](py::tuple t) { \ return ProximalDecoder{ \ diff --git a/sw/plot_heatmaps.py b/sw/plot_heatmaps.py index 5919a60..c596b35 100644 --- a/sw/plot_heatmaps.py +++ b/sw/plot_heatmaps.py @@ -18,39 +18,39 @@ def format_yticks(previous_yticks): def main(): sns.set_theme() - # titles = [ - # "$n=7$, $m=4$", - # "$n=31$, $m=20$", - # "$n=31$, $m=5$", - # "$n=96$, $m=48$", - # "$n=204$, $m=102$", - # "$n=408$, $m=204$", - # ] - # - # filenames = [ - # "sim_results/2d_dec_fails_w_log_k_lin_bch_7_4alist.csv", - # "sim_results/2d_dec_fails_w_log_k_lin_bch_31_11alist.csv", - # "sim_results/2d_dec_fails_w_log_k_lin_bch_31_26alist.csv", - # "sim_results/2d_dec_fails_w_log_k_lin_963965alist.csv", - # "sim_results/2d_dec_fails_w_log_k_lin_2043486alist.csv", - # "sim_results/2d_dec_fails_w_log_k_lin_40833844alist.csv", - # ] - titles = [ "$n=7$, $m=4$", "$n=31$, $m=20$", "$n=31$, $m=5$", - "$n=96$, $m=48$" + "$n=96$, $m=48$", + "$n=204$, $m=102$", + "$n=408$, $m=204$", ] filenames = [ - "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_7_4alist.csv", - "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_31_11alist.csv", - "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_31_26alist.csv", - "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_963965alist.csv", + "sim_results/2d_dec_fails_w_log_k_lin_bch_7_4alist.csv", + "sim_results/2d_dec_fails_w_log_k_lin_bch_31_11alist.csv", + "sim_results/2d_dec_fails_w_log_k_lin_bch_31_26alist.csv", + "sim_results/2d_dec_fails_w_log_k_lin_963965alist.csv", + "sim_results/2d_dec_fails_w_log_k_lin_2043486alist.csv", + "sim_results/2d_dec_fails_w_log_k_lin_40833844alist.csv", ] - fig, axes = plt.subplots(2, 2, squeeze=False) + # titles = [ + # "$n=7$, $m=4$", + # "$n=31$, $m=20$", + # "$n=31$, $m=5$", + # "$n=96$, $m=48$" + # ] + # + # filenames = [ + # "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_7_4alist.csv", + # "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_31_11alist.csv", + # "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_31_26alist.csv", + # "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_963965alist.csv", + # ] + + fig, axes = plt.subplots(2, 3, squeeze=False) fig.suptitle("SNR = 3dB") diff --git a/sw/simulate_2d_dec_fails.py b/sw/simulate_2d_dec_fails.py index c38d501..7b26e8d 100644 --- a/sw/simulate_2d_dec_fails.py +++ b/sw/simulate_2d_dec_fails.py @@ -23,7 +23,7 @@ def task_func(params): x = noise.add_awgn(x_bpsk, SNR, n, k) x_hat, num_iter = decoder.decode(x) - if x_hat is None: + if num_iter == -1: dec_fails += 1 return dec_fails / num_iterations @@ -80,7 +80,7 @@ def main(): H_file = "BCH_31_26.alist" SNR = 3 - num_iterations = 10000 + num_iterations = 1000 omegas = np.logspace(-0.3, -2.82, 40) Ks = np.ceil(np.linspace(10 ** 1.3, 10 ** 2.3, 40)).astype('int32') diff --git a/sw/simulate_BER_curve.py b/sw/simulate_BER_curve.py index e69de29..6816d09 100644 --- a/sw/simulate_BER_curve.py +++ b/sw/simulate_BER_curve.py @@ -0,0 +1,145 @@ +import numpy as np +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +import signal +from timeit import default_timer +from tqdm import tqdm + +from utility import codes, noise, misc +from utility.simulation.simulators import GenericMultithreadedSimulator + +# from cpp_modules.cpp_decoders import ProximalDecoder +from cpp_modules.cpp_decoders import ProximalDecoder_7_3 as ProximalDecoder + + +def count_bit_errors(d: np.array, d_hat: np.array) -> int: + return np.sum(d != d_hat) + + +def task_func(params): + signal.signal(signal.SIGINT, signal.SIG_IGN) + + decoder, max_iterations, SNR, n, k = params + c = np.zeros(n) + x_bpsk = c + 1 + + total_bit_errors = 0 + total_frame_errors = 0 + dec_fails = 0 + + num_iterations = 0 + + for i in range(max_iterations): + x = noise.add_awgn(x_bpsk, SNR, n, k) + x_hat, k_max = decoder.decode(x) + + bit_errors = count_bit_errors(x_hat, c) + if bit_errors > 0: + total_bit_errors += bit_errors + total_frame_errors += 1 + + num_iterations += 1 + + if k_max == -1: + dec_fails += 1 + + if total_frame_errors > 4000: + break + + BER = total_bit_errors / (num_iterations * n) + FER = total_frame_errors / num_iterations + DFR = dec_fails / (num_iterations + dec_fails) + + return BER, FER, DFR, num_iterations + + +def simulate(H_file, SNRs, max_iterations, omega, K, gammas): + sim = GenericMultithreadedSimulator() + + # Define fixed simulation params + + H = codes.read_alist_file(f"res/{H_file}") + n_min_k, n = H.shape + k = n - n_min_k + + # Define params different for each task + + params = {} + for i, SNR in enumerate(SNRs): + for j, gamma in enumerate(gammas): + decoder = ProximalDecoder(H=H.astype('int32'), K=K, omega=omega, + gamma=gamma) + params[f"{i}_{j}"] = (decoder, max_iterations, SNR, n, k) + + # Set up simulation + + sim.task_params = params + sim.task_func = task_func + + sim.start_or_continue() + + return sim.get_current_results() + + +def reformat_data(results, SNRs, gammas): + data = {"BER": np.zeros(3 * 10), "FER": np.zeros(3 * 10), + "DFR": np.zeros(3 * 10), "gamma": np.zeros(3 * 10), + "SNR": np.zeros(3 * 10), "num_iter": np.zeros(3 * 10)} + + for i, (key, (BER, FER, DFR, num_iter)) in enumerate(results.items()): + i_SNR, i_gamma = key.split('_') + data["BER"][i] = BER + data["FER"][i] = FER + data["DFR"][i] = DFR + data["num_iter"][i] = num_iter + data["SNR"][i] = SNRs[int(i_SNR)] + data["gamma"][i] = gammas[int(i_gamma)] + + print(pd.DataFrame(data)) + return pd.DataFrame(data) + + +def main(): + # Set up simulation params + + sim_name = "BER_FER_DFR" + + # H_file = "96.3.965.alist" + # H_file = "204.3.486.alist" + # H_file = "204.55.187.alist" + # H_file = "408.33.844.alist" + H_file = "BCH_7_4.alist" + # H_file = "BCH_31_11.alist" + # H_file = "BCH_31_26.alist" + SNRs = np.arange(1, 6, 0.5) + + max_iterations = 10000 + # omega = 0.005 + # K = 60 + omega = 0.05 + K = 60 + gammas = [0.15, 0.01, 0.05] + + # Run simulation + + start_time = default_timer() + results = simulate(H_file, SNRs, max_iterations, omega, K, gammas) + end_time = default_timer() + + print(f"duration: {end_time - start_time}") + + df = reformat_data(results, SNRs, gammas) + + df.to_csv( + f"sim_results/2d_dec_fails_{sim_name}_{misc.slugify(H_file)}.csv") + + sns.set_theme() + ax = sns.lineplot(data=df, x="SNR", y="BER", hue="gamma") + ax.set_yscale('log') + ax.set_ylim((5e-5, 2e-0)) + plt.show() + + +if __name__ == "__main__": + main()