Fix simulate_error_rate.py

This commit is contained in:
Andreas Tsouchlos 2025-05-13 00:04:07 +02:00
parent ad274699a3
commit f12339f89f

View File

@ -29,17 +29,14 @@ class SimulationArgs:
def decode(tracker, y, H, args: SimulationArgs) -> np.ndarray: def decode(tracker, y, H, args: SimulationArgs) -> np.ndarray:
x_hat = np.mod(np.round(y), 2).astype('int32') x_hat = np.where(y >= 0, 0, 1)
s = np.concatenate([y, np.array([0])]) s = np.concatenate([y, np.array([0])])
for i in range(args.homotopy_iter): for i in range(args.homotopy_iter):
x_hat = np.mod(np.round(s[:-1]), 2).astype('int32') x_hat = np.where(s[:-1] >= 0, 0, 1)
if not np.any(np.mod(H @ x_hat, 2)): if not np.any(np.mod(H @ x_hat, 2)):
return x_hat return x_hat
# if s[-1] > 1.5:
# return x_hat
try: try:
s = tracker.step(s) s = tracker.step(s)
except: except:
@ -67,22 +64,28 @@ def simulate_error_rates_for_SNR(H, Eb_N0, args: SimulationArgs) -> typing.Tuple
for _ in tqdm(range(args.max_frames)): for _ in tqdm(range(args.max_frames)):
Eb_N0_lin = 10**(Eb_N0 / 10) Eb_N0_lin = 10**(Eb_N0 / 10)
N0 = 1 / (2 * k / n * Eb_N0_lin) N0 = 1 / (2 * k / n * Eb_N0_lin)
y = np.zeros(n) + np.sqrt(N0) * np.random.normal(size=n)
u = np.random.randint(2, size=k)
# u = np.zeros(shape=k).astype(np.int32)
c = np.array(GF(u) @ G)
x = 1 - 2*c
y = x + np.sqrt(N0) * np.random.normal(size=n)
homotopy.update_received(y) homotopy.update_received(y)
x_hat = decode(tracker, y, H, args) c_hat = decode(tracker, y, H, args)
if np.any(np.mod(H @ x_hat, 2)): if np.any(np.mod(H @ c_hat, 2)):
tracker.set_sigma(-1 * args.sigma) tracker.set_sigma(-1 * args.sigma)
x_hat = decode(tracker, y, H, args) c_hat = decode(tracker, y, H, args)
tracker.set_sigma(args.sigma) tracker.set_sigma(args.sigma)
if np.any(np.mod(H @ x_hat, 2)): if np.any(np.mod(H @ c_hat, 2)):
decoding_failures += 1 decoding_failures += 1
bit_errors += np.sum(x_hat != np.zeros(n)) bit_errors += np.sum(c_hat != c)
frame_errors += np.any(x_hat != np.zeros(n)) frame_errors += np.any(c_hat != c)
num_frames += 1 num_frames += 1
if frame_errors >= args.target_frame_errors: if frame_errors >= args.target_frame_errors:
@ -95,7 +98,7 @@ def simulate_error_rates_for_SNR(H, Eb_N0, args: SimulationArgs) -> typing.Tuple
return FER, BER, DFR, frame_errors return FER, BER, DFR, frame_errors
def simulate_error_rates(H, Eb_N0_list, args: SimulationArgs) -> typing.Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: def simulate_error_rates(H, Eb_N0_list, args: SimulationArgs) -> pd.DataFrame:
FERs = [] FERs = []
BERs = [] BERs = []
DFRs = [] DFRs = []
@ -106,8 +109,11 @@ def simulate_error_rates(H, Eb_N0_list, args: SimulationArgs) -> typing.Tuple[np
BERs.append(BER) BERs.append(BER)
DFRs.append(DFR) DFRs.append(DFR)
frame_errors_list.append(num_frames) frame_errors_list.append(num_frames)
print(pd.DataFrame({"SNR": Eb_N0_list[:len(FERs)], "FER": FERs, "BER": BERs,
"DFR": DFRs, "frame_errors": frame_errors_list}))
return np.array(FERs), np.array(BERs), np.array(DFRs), np.array(frame_errors_list) return pd.DataFrame({"SNR": Eb_N0_list, "FER": FERs, "BER": BERs,
"DFR": DFRs, "frame_errors": frame_errors_list})
def main(): def main():
@ -157,11 +163,8 @@ def main():
target_frame_errors=args.target_frame_errors) target_frame_errors=args.target_frame_errors)
SNRs = np.arange(args.snr[0], args.snr[1], args.snr[2]) SNRs = np.arange(args.snr[0], args.snr[1], args.snr[2])
FERs, BERs, DFRs, frame_errors_list = simulate_error_rates( df = simulate_error_rates(H, SNRs, simulation_args)
H, SNRs, simulation_args)
df = pd.DataFrame({"SNR": SNRs, "FER": FERs, "BER": BERs,
"DFR": DFRs, "frame_errors": frame_errors_list})
print(df) print(df)