#pragma once

#define EIGEN_STACK_ALLOCATION_LIMIT 524288

#include <Eigen/Dense>

#include <bit>
#include <cstdint>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

#include <pybind11/eigen.h>
#include <pybind11/stl.h>

/*
 *
 * Using declarations
 *
 */

template <int t_rows, int t_cols>
using MatrixiR = Eigen::Matrix<int, t_rows, t_cols, Eigen::RowMajor>;

template <int t_rows, int t_cols>
using MatrixdR = Eigen::Matrix<double, t_rows, t_cols, Eigen::RowMajor>;

template <int t_size>
using RowVectori = Eigen::RowVector<int, t_size>;

template <int t_size>
using RowVectord = Eigen::RowVector<double, t_size>;

/*
 *
 * Proximal decoder implementation
 *
 */

/**
 * @brief Class implementing the Proximal Decoding algorithm. See "Proximal
 * Decoding for LDPC Codes" by Tadashi Wadayama and Satoshi Takabe.
 * @tparam t_m Number of rows of the H matrix
 * @tparam t_n Number of columns of the H matrix
 */
template <int t_m, int t_n>
class ProximalDecoder {
public:
  /**
   * @brief Constructor
   * @param H Parity-check matrix
   * @param K Number of iterations to run while decoding
   * @param omega Step size
   * @param gamma Positive constant. Arises in the approximation of the prior
   * PDF
   * @param eta Positive constant slightly larger than one. See 3.2, p. 3
   */
  ProximalDecoder(const Eigen::Ref<const MatrixiR<t_m, t_n>>& H, int K,
                  double omega, double gamma, double eta)
      : mK(K), mOmega(omega), mGamma(gamma), mEta(eta), mH(H),
        mH_zero_indices(find_zero(H)) {
    Eigen::setNbThreads(8);
  }

  /**
   * @brief Decode a received signal. The algorithm is detailed in 3.2, p. 3.
   * This function assumes a BPSK-modulated signal and an AWGN channel.
   * @param y Vector of received values (y = x + w, where x is an element of
   * {-1, +1}^n and w is noise)
   * @return Pair of the most likely transmitted codeword (an element of
   * {0, 1}^n) and the number of iterations used. If decoding fails, the
   * codeword is empty (std::nullopt, exposed as 'None' in Python) and the
   * iteration count equals K.
   */
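  //
  // Per iteration the code below performs (see Sec. 3.2 of the paper for the
  // exact formulation):
  //
  //   r <- s - omega * L_awgn(s, y)                 (gradient step on the
  //                                                  negative log-likelihood)
  //   s <- Proj_[-eta, eta](r - gamma * grad_H(r))  (gradient step on the
  //                                                  code-constraint polynomial)
  //
  // followed by a hard decision on the sign of s and a parity check of the
  // resulting candidate codeword.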
  std::pair<std::optional<RowVectori<t_n>>, int>
  decode(const Eigen::Ref<const RowVectord<t_n>>& y) {
    if (y.size() != mH.cols())
      throw std::runtime_error(
          "Length of received vector must match the number of columns of H");

    RowVectord<t_n> s = RowVectord<t_n>::Zero(t_n);
    RowVectori<t_n> x_hat;
    RowVectord<t_n> r;

    for (int i = 0; i < mK; ++i) {
      // Likelihood (gradient) step
      r = s - mOmega * L_awgn(s, y);

      // Code-constraint step, clipped to [-eta, eta]^n
      s = projection(r - mGamma * grad_H(r));

      // Hard decision on the sign of s
      x_hat = s.unaryExpr([](double d) {
                 std::uint64_t bits = std::bit_cast<std::uint64_t>(d);
                 // Return the sign bit: 1 for negative, 0 for positive
                 return (bits >> 63);
               }).template cast<int>();

      if (check_parity(x_hat)) {
        return {x_hat, i + 1};
      }
    }

    return {std::nullopt, mK};
  }

  /// Members are kept public (rather than private) so that the class remains
  /// easily picklable
  // private:
  const int mK;
  const double mOmega;
  const double mGamma;
  const double mEta;

  const MatrixiR<t_m, t_n> mH;
  const std::vector<Eigen::Index> mH_zero_indices;

  /**
   * Gradient of the negative log-likelihood, specialized to AWGN noise.
   * See 4.1, p. 4.
   */
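  //
  // Where this comes from (a sketch; the exact constants are in the paper):
  // for y = x + w with w ~ N(0, sigma^2 I), the negative log-likelihood is
  // L(s) = ||s - y||^2 / (2 sigma^2) up to an additive constant, so
  // grad L(s) = (s - y) / sigma^2. The 1/sigma^2 factor can be absorbed into
  // the step size omega, which is what the implementation below does by
  // returning just (s - y).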
  static RowVectord<t_n> L_awgn(const RowVectord<t_n>& s,
                                const RowVectord<t_n>& y) {
    return s - y;
  }

  /**
   * @brief Find all linear indices of a matrix whose corresponding entry is
   * zero
   * @return \b std::vector of indices
   */
  static std::vector<Eigen::Index> find_zero(MatrixiR<t_m, t_n> mat) {
    std::vector<Eigen::Index> indices;

    for (Eigen::Index i = 0; i < mat.size(); ++i)
      if (mat(i) == 0) indices.push_back(i);

    return indices;
  }

  /**
   * Gradient of the code-constraint polynomial. See 2.3, p. 2.
   */
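  //
  // A reconstruction of the objective being differentiated (read off from the
  // arithmetic below; see the paper for the exact definition): with
  //
  //   A_i = prod_{j : H(i,j) = 1} x_j,
  //
  // the code-constraint polynomial is, up to constant factors,
  //
  //   h(x) = sum_j (x_j^2 - 1)^2 + sum_i (A_i - 1)^2,
  //
  // whose j-th partial derivative is
  //
  //   4 (x_j^2 - 1) x_j + (2 / x_j) * sum_{i : H(i,j) = 1} (A_i^2 - A_i).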
  RowVectord<t_n> grad_H(const RowVectord<t_n>& x) {
    // Replicate x across the rows and overwrite entries where H is zero with
    // 1, so that each row product only involves positions where H is one.
    MatrixdR<t_m, t_n> A_prod_matrix = x.replicate(t_m, 1);

    for (const auto& index : mH_zero_indices)
      A_prod_matrix(index) = 1;

    // A_prods(i) = prod_{j : H(i,j) = 1} x(j)
    RowVectord<t_m> A_prods = A_prod_matrix.rowwise().prod().transpose();

    RowVectord<t_m> B_terms =
        (A_prods.array().pow(2) - A_prods.array()).matrix();

    // B_sums(j) = sum_{i : H(i,j) = 1} (A_i^2 - A_i)
    RowVectord<t_n> B_sums = B_terms * mH.template cast<double>();

    RowVectord<t_n> result = 4 * (x.array().pow(2) - 1) * x.array() +
                             (2 * x.array().inverse()) * B_sums.array();

    return result;
  }

  /**
   * Perform a parity check for a given codeword.
   * @param x_hat Codeword to be checked (an element of {0, 1}^n)
   * @return \b True if the parity check passes, i.e. the codeword is valid.
   * False otherwise
   */
  bool check_parity(const RowVectori<t_n>& x_hat) {
    // Syndrome = H * x_hat^T (mod 2); the codeword is valid iff it is zero
    RowVectori<t_m> syndrome =
        (mH * x_hat.transpose()).unaryExpr([](int i) { return i % 2; });

    return syndrome.count() == 0;
  }

  /**
   * Project a vector onto [-eta, eta]^n in order to avoid numerical
   * instability. Detailed in 3.2, p. 3 (Equation (15)).
   * @param v Vector to project
   * @return v clipped to [-eta, eta]^n
   */
  RowVectord<t_n> projection(const RowVectord<t_n>& v) {
    return v.cwiseMin(mEta).cwiseMax(-mEta);
  }
};
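
/*
 * Example usage (an illustrative sketch, not part of the original interface).
 * The function name, the (7, 4) Hamming parity-check matrix, and the
 * parameter values K, omega, gamma and eta below are placeholder choices for
 * demonstration only.
 */
inline void proximal_decoder_example() {
  MatrixiR<3, 7> H;
  H << 1, 0, 1, 0, 1, 0, 1,
       0, 1, 1, 0, 0, 1, 1,
       0, 0, 0, 1, 1, 1, 1;

  ProximalDecoder<3, 7> decoder(H, /*K=*/100, /*omega=*/0.05, /*gamma=*/0.05,
                                /*eta=*/1.2);

  // Noisy BPSK observations whose hard decision is the valid codeword 0101010
  RowVectord<7> y;
  y << 0.9, -1.1, 1.2, -0.8, 1.05, -0.95, 1.1;

  auto [codeword, iterations] = decoder.decode(y);
  if (codeword)
    std::cout << "Decoded " << *codeword << " after " << iterations
              << " iteration(s)\n";
  else
    std::cout << "Decoding failed after " << iterations << " iteration(s)\n";
}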
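
/*
 * The pybind11 headers included above suggest that this class is exposed to
 * Python. A binding for one concrete (t_m, t_n) instantiation might look
 * roughly like the sketch below, placed in a separate translation unit; the
 * module name and the (3, 7) dimensions are placeholders.
 *
 *   PYBIND11_MODULE(proximal_decoder, m) {
 *     namespace py = pybind11;
 *     using Decoder = ProximalDecoder<3, 7>;
 *
 *     py::class_<Decoder>(m, "ProximalDecoder")
 *         .def(py::init<const Eigen::Ref<const MatrixiR<3, 7>>&, int, double,
 *                       double, double>(),
 *              py::arg("H"), py::arg("K"), py::arg("omega"), py::arg("gamma"),
 *              py::arg("eta"))
 *         .def("decode", &Decoder::decode, py::arg("y"));
 *   }
 */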