Added simulate_gradient.py

This commit is contained in:
Andreas Tsouchlos 2022-12-07 14:13:14 +01:00
parent 765ba88773
commit 4736061ada

85
sw/simulate_gradient.py Normal file
View File

@ -0,0 +1,85 @@
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import signal
from timeit import default_timer
from tqdm import tqdm
from utility import codes, noise, misc
from utility.simulation.simulators import GenericMultithreadedSimulator
# from cpp_modules.cpp_decoders import ProximalDecoder
from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder
def simulate(H_file, SNR, omega, K, gamma):
H = codes.read_alist_file(f"res/{H_file}")
n_min_k, n = H.shape
k = n - n_min_k
decoder = ProximalDecoder(H.astype('int32'), K=K, omega=omega, gamma=gamma)
c = np.zeros(n)
x_bpsk = (c + 1)
avg_grad_values = np.zeros(shape=(K, 2))
for i in range(1000):
x = noise.add_awgn(x_bpsk, SNR, n, k)
grad_values = decoder.get_gradient_values(x)
for j, (val_h, val_l) in enumerate(grad_values):
avg_grad_values[j, 0] += val_h
avg_grad_values[j, 1] += val_l
avg_grad_values = avg_grad_values / 1000
return avg_grad_values
def reformat_data(results):
return pd.DataFrame({"k": np.arange(0, results.size // 2, 1), "grad_h": results[:, 0], "grad_l": results[:, 1]})
def main():
# Set up simulation params
sim_name = "avg_grad_1dB"
# H_file = "96.3.965.alist"
H_file = "204.33.486.alist"
# H_file = "204.33.484.alist"
# H_file = "204.55.187.alist"
# H_file = "408.33.844.alist"
# H_file = "BCH_7_4.alist"
# H_file = "BCH_31_11.alist"
# H_file = "BCH_31_26.alist"
SNR = 1
omega = 0.05
K = 100
gamma = 0.05
# Run simulation
start_time = default_timer()
results = simulate(H_file, SNR, omega, K, gamma)
end_time = default_timer()
print(f"duration: {end_time - start_time}")
df = reformat_data(results)
df.to_csv(
f"sim_results/{sim_name}_{misc.slugify(H_file)}.csv", index=False)
sns.set_theme()
sns.lineplot(data=df, x="k", y="grad_h")
sns.lineplot(data=df, x="k", y="grad_l")
plt.show()
if __name__ == "__main__":
main()