#pragma once

#define EIGEN_STACK_ALLOCATION_LIMIT 524288

#include <Eigen/Dense>

#include <bit>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

/*
 *
 * Using declarations
 *
 */

template <int rows, int cols>
using MatrixiR = Eigen::Matrix<int, rows, cols, Eigen::RowMajor>;

template <int rows, int cols>
using MatrixdR = Eigen::Matrix<double, rows, cols, Eigen::RowMajor>;

template <int length>
using RowVectori = Eigen::RowVector<int, length>;

template <int length>
using RowVectord = Eigen::RowVector<double, length>;

/*
 *
 * Proximal decoder implementation
 *
 */

/**
 * @brief Class implementing the Proximal Decoding algorithm. See "Proximal
 *        Decoding for LDPC Codes" by Tadashi Wadayama and Satoshi Takabe.
 * @tparam t_m Number of rows of the H matrix
 * @tparam t_n Number of columns of the H matrix
 */
template <int t_m, int t_n>
class ProximalDecoder {
 public:
  /**
   * @brief Constructor
   * @param H Parity-check matrix
   * @param K Number of iterations to run while decoding
   * @param omega Step size
   * @param gamma Positive constant. Arises in the approximation of the prior
   *              PDF
   * @param eta Positive constant slightly larger than one. See 3.2, p. 3
   */
  ProximalDecoder(const Eigen::Ref<const MatrixiR<t_m, t_n>>& H, int K,
                  double omega, double gamma, double eta)
      : mK(K),
        mOmega(omega),
        mGamma(gamma),
        mEta(eta),
        mH(H),
        mH_zero_indices(find_zero(H)) {
    Eigen::setNbThreads(8);
  }

  /**
   * @brief Decode a received signal. The algorithm is detailed in 3.2, p. 3.
   *        This function assumes a BPSK-modulated signal and an AWGN channel.
   * @param y Vector of received values (y = x + w, where 'x' is an element of
   *          {-1, +1}^n and 'w' is noise)
   * @return Pair of the most probable transmitted codeword (element of
   *         {0, 1}^n) and the number of iterations used. If decoding fails,
   *         the returned codeword is std::nullopt.
   */
  std::pair<std::optional<RowVectori<t_n>>, int> decode(
      const Eigen::Ref<const RowVectord<t_n>>& y) {
    if (y.size() != mH.cols())
      throw std::runtime_error("Length of vector must match H matrix");

    RowVectord<t_n> s = RowVectord<t_n>::Zero(t_n);
    RowVectori<t_n> x_hat;
    RowVectord<t_n> r;

    for (int i = 0; i < mK; ++i) {
      // Gradient step on the negative log-likelihood, then a step along the
      // gradient of the code-constraint polynomial followed by a projection.
      r = s - mOmega * L_awgn(s, y);
      s = projection(r - mGamma * grad_H(r));

      x_hat = s.unaryExpr([](double d) {
                 uint64_t bits = std::bit_cast<uint64_t>(d);
                 // Return the sign bit: 1 for negative, 0 for positive
                 return (bits >> 63);
               }).template cast<int>();

      if (check_parity(x_hat)) {
        return {x_hat, i + 1};
      }
    }
    return {std::nullopt, mK};
  }

  /// Private members are not private in order to make the class easily
  /// picklable
  // private:
  const int mK;
  const double mOmega;
  const double mGamma;
  const double mEta;
  const MatrixiR<t_m, t_n> mH;
  const std::vector<Eigen::Index> mH_zero_indices;

  /**
   * Variation of the negative log-likelihood for the special case of AWGN
   * noise. See 4.1, p. 4.
   */
  static Eigen::RowVectorXd L_awgn(const RowVectord<t_n>& s,
                                   const RowVectord<t_n>& y) {
    return s.array() - y.array();
  }

  /**
   * @brief Find all indices of a matrix where the corresponding value is zero
   * @return \b std::vector of indices
   */
  static std::vector<Eigen::Index> find_zero(MatrixiR<t_m, t_n> mat) {
    std::vector<Eigen::Index> indices;
    for (Eigen::Index i = 0; i < mat.size(); ++i)
      if (mat(i) == 0) indices.push_back(i);
    return indices;
  }

  /**
   * Gradient of the code-constraint polynomial. See 2.3, p. 2.
   */
  RowVectord<t_n> grad_H(const RowVectord<t_n>& x) {
    MatrixdR<t_m, t_n> A_prod_matrix = x.replicate(t_m, 1);
    for (const auto& index : mH_zero_indices) A_prod_matrix(index) = 1;

    // The row-wise products form a column vector with one entry per check.
    Eigen::Vector<double, t_m> A_prods = A_prod_matrix.rowwise().prod();
    RowVectord<t_m> B_terms =
        (A_prods.array().pow(2) - A_prods.array()).matrix().transpose();
    RowVectord<t_n> B_sums = B_terms * mH.template cast<double>();

    RowVectord<t_n> result = 4 * (x.array().pow(2) - 1) * x.array() +
                             (2 * x.array().inverse()) * B_sums.array();
    return result;
  }

  /**
   * Perform a parity check for a given codeword.
   * @param x_hat Codeword to be checked (element of {0, 1}^n)
   * @return \b True if the parity check passes, i.e. the codeword is valid.
   *         False otherwise
   */
  bool check_parity(const RowVectori<t_n>& x_hat) {
    // H * x_hat^T is a column vector; reduce each entry modulo 2 to obtain
    // the syndrome.
    Eigen::Vector<int, t_m> syndrome =
        (mH * x_hat.transpose()).unaryExpr([](int i) { return i % 2; });
    return !(syndrome.count() > 0);
  }

  /**
   * Project a vector onto [-eta, eta]^n in order to avoid numerical
   * instability. Detailed in 3.2, p. 3 (Equation (15)).
   * @param v Vector to project
   * @return v clipped to [-eta, eta]^n
   */
  RowVectord<t_n> projection(const RowVectord<t_n>& v) {
    return v.cwiseMin(mEta).cwiseMax(-mEta);
  }
};
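
/*
 *
 * Usage example
 *
 */

// Minimal usage sketch, kept as a comment so the header itself is unchanged
// when included. Assumptions not taken from the header above: the include
// path "proximal_decoder.hpp" is hypothetical, H below is a (7, 4) Hamming
// code parity-check matrix, and the K/omega/gamma/eta values are illustrative
// placeholders rather than recommended settings.
//
//   #include <iostream>
//   #include "proximal_decoder.hpp"
//
//   int main() {
//     MatrixiR<3, 7> H;
//     H << 1, 0, 1, 0, 1, 0, 1,
//          0, 1, 1, 0, 0, 1, 1,
//          0, 0, 0, 1, 1, 1, 1;
//
//     ProximalDecoder<3, 7> decoder(H, /*K=*/100, /*omega=*/0.05,
//                                   /*gamma=*/0.05, /*eta=*/1.2);
//
//     // Noisy BPSK observation y = x + w; under the sign convention used in
//     // decode(), the all-zero codeword maps to the all-(+1) signal.
//     RowVectord<7> y;
//     y << 0.9, 1.1, 0.8, 1.2, 0.7, 1.0, 0.95;
//
//     auto [x_hat, iterations] = decoder.decode(y);
//     if (x_hat)
//       std::cout << *x_hat << " (" << iterations << " iterations)\n";
//     else
//       std::cout << "decoding failed\n";
//   }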