Fixed usages of x, x_hat, y
This commit is contained in:
parent
70bbe08bc4
commit
23e318609c
@ -48,22 +48,22 @@ class ProximalDecoder:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _projection(self, x):
|
def _projection(self, v):
|
||||||
"""Project a vector onto [-eta, eta]^n in order to avoid numerical instability.
|
"""Project a vector onto [-eta, eta]^n in order to avoid numerical instability.
|
||||||
Detailed in 3.2, p. 3 (Equation (15)).
|
Detailed in 3.2, p. 3 (Equation (15)).
|
||||||
|
|
||||||
:param x: Vector to project
|
:param v: Vector to project
|
||||||
:return: x clipped to [-eta, eta]^n
|
:return: x clipped to [-eta, eta]^n
|
||||||
"""
|
"""
|
||||||
return np.clip(x, -self._eta, self._eta)
|
return np.clip(v, -self._eta, self._eta)
|
||||||
|
|
||||||
def _check_parity(self, y_hat: np.array) -> bool:
|
def _check_parity(self, x_hat: np.array) -> bool:
|
||||||
"""Perform a parity check for a given codeword.
|
"""Perform a parity check for a given codeword.
|
||||||
|
|
||||||
:param y_hat: codeword to be checked (element of [0, 1]^n)
|
:param x_hat: codeword to be checked (element of [0, 1]^n)
|
||||||
:return: True if the parity check passes, i.e. the codeword is valid. False otherwise
|
:return: True if the parity check passes, i.e. the codeword is valid. False otherwise
|
||||||
"""
|
"""
|
||||||
syndrome = np.dot(self._H, y_hat) % 2
|
syndrome = np.dot(self._H, x_hat) % 2
|
||||||
return not np.any(syndrome)
|
return not np.any(syndrome)
|
||||||
|
|
||||||
def decode(self, y: np.array) -> np.array:
|
def decode(self, y: np.array) -> np.array:
|
||||||
|
|||||||
@ -13,7 +13,7 @@ import numpy as np
|
|||||||
|
|
||||||
def _parse_alist_header(header):
|
def _parse_alist_header(header):
|
||||||
size = header.split()
|
size = header.split()
|
||||||
return int(size[0,]), int(size[1])
|
return int(size[0]), int(size[1])
|
||||||
|
|
||||||
|
|
||||||
def read_alist_file(filename):
|
def read_alist_file(filename):
|
||||||
|
|||||||
@ -54,9 +54,9 @@ def test_decoder(x: np.array,
|
|||||||
|
|
||||||
y = noise.add_awgn(x_bpsk, SNR, signal_amp=np.sqrt(2))
|
y = noise.add_awgn(x_bpsk, SNR, signal_amp=np.sqrt(2))
|
||||||
|
|
||||||
y_hat = decoder.decode(y)
|
x_hat = decoder.decode(y)
|
||||||
|
|
||||||
total_bit_errors += count_bit_errors(x, y_hat)
|
total_bit_errors += count_bit_errors(x, x_hat)
|
||||||
total_bits += x.size
|
total_bits += x.size
|
||||||
|
|
||||||
if total_bit_errors >= target_bit_errors:
|
if total_bit_errors >= target_bit_errors:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user