Reimplemented gradient computation with more numpy builtin functions
This commit is contained in:
parent
b7c4b4e359
commit
97acaaafd3
@ -7,8 +7,6 @@ class ProximalDecoder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Is 'R' actually called 'decoding matrix'?
|
# TODO: Is 'R' actually called 'decoding matrix'?
|
||||||
# TODO: How large should eta be?
|
|
||||||
# TODO: How large should step_size be?
|
|
||||||
def __init__(self, H: np.array, R: np.array, K: int = 100, step_size: float = 0.1,
|
def __init__(self, H: np.array, R: np.array, K: int = 100, step_size: float = 0.1,
|
||||||
gamma: float = 0.05, eta: float = 1.5):
|
gamma: float = 0.05, eta: float = 1.5):
|
||||||
"""Construct a new ProximalDecoder Object.
|
"""Construct a new ProximalDecoder Object.
|
||||||
@ -48,24 +46,20 @@ class ProximalDecoder:
|
|||||||
# TODO: Is this correct?
|
# TODO: Is this correct?
|
||||||
def _grad_h(self, x: np.array) -> np.array:
|
def _grad_h(self, x: np.array) -> np.array:
|
||||||
"""Gradient of the code-constraint polynomial. See 2.3, p. 2."""
|
"""Gradient of the code-constraint polynomial. See 2.3, p. 2."""
|
||||||
# Calculate first term
|
# Pre-computations
|
||||||
|
|
||||||
result = 4 * (x**2 - 1) * x
|
k, _ = self._H.shape
|
||||||
|
|
||||||
# Calculate second term
|
A_prod_matrix = np.tile(x, (k, 1))
|
||||||
|
A_prods = np.prod(A_prod_matrix, axis=1, where=self._H > 0)
|
||||||
|
|
||||||
for k, x_k in enumerate(x):
|
# Calculate gradient
|
||||||
sum_result = 0
|
|
||||||
|
|
||||||
for i in self._B[k]:
|
sums = np.dot(A_prods**2 - A_prods, self._H)
|
||||||
prod = np.prod(x[self._A[i]])
|
|
||||||
sum_result += prod**2 - prod
|
|
||||||
|
|
||||||
term_2 = 2 / x_k * sum_result
|
result = 4 * (x**2 - 1) * x + (2 / x) * sums
|
||||||
|
|
||||||
result[k] += term_2
|
return result
|
||||||
|
|
||||||
return np.array(result)
|
|
||||||
|
|
||||||
# TODO: Is the 'projection onto [-eta, eta]' actually just clipping?
|
# TODO: Is the 'projection onto [-eta, eta]' actually just clipping?
|
||||||
def _projection(self, x):
|
def _projection(self, x):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user