diff --git a/sw/cpp/CMakeLists.txt b/sw/cpp/CMakeLists.txt index 4014e74..d7d8baa 100644 --- a/sw/cpp/CMakeLists.txt +++ b/sw/cpp/CMakeLists.txt @@ -26,7 +26,7 @@ find_package(OpenMP REQUIRED) #add_compile_options(-ffast-math) -pybind11_add_module(cpp_decoders src/cpp_decoders.cpp) +pybind11_add_module(cpp_decoders src/python_interface.cpp) target_link_libraries(cpp_decoders PRIVATE Eigen3::Eigen OpenMP::OpenMP_CXX) set(INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../cpp_modules) diff --git a/sw/cpp/src/cpp_decoders.cpp b/sw/cpp/src/cpp_decoders.cpp deleted file mode 100644 index 39a5c27..0000000 --- a/sw/cpp/src/cpp_decoders.cpp +++ /dev/null @@ -1,136 +0,0 @@ -#include -#include -#include -#include - -#include -#include -#include -#include - - -namespace py11 = pybind11; -using namespace pybind11::literals; - - -using MatrixXiR = - Eigen::Matrix; - -using MatrixXdR = - Eigen::Matrix; - - -class ProximalDecoder { -public: - ProximalDecoder(const Eigen::Ref& H, int K, double omega, - double gamma, double eta) - : mK(K), mOmega(omega), mGamma(gamma), mEta(eta), mH(H), - mH_zero_indices(find_zero(H)) { - - Eigen::setNbThreads(8); - } - - std::pair, int> - decode(const Eigen::Ref& y) { - if (y.size() != mH.cols()) - throw std::runtime_error("Length of vector must match H matrix"); - - Eigen::RowVectorXd s = Eigen::RowVectorXd::Zero(mH.cols()); - Eigen::RowVectorXi x_hat; - Eigen::RowVectorXd r; - - for (std::size_t i = 0; i < mK; ++i) { - r = s - mOmega * L_awgn(s, y); - - s = projection(r - mGamma * grad_H(r)); - - x_hat = s.unaryExpr([](double d) { - uint64_t bits = std::bit_cast(d); - // Return the sign bit: 1 for negative, 0 for positive - return (bits >> 63); - }).cast(); - - if (check_parity(x_hat)) { - return {x_hat, i + 1}; - } - } - - return {std::nullopt, mK}; - } - - // private: - const int mK; - const double mOmega; - const double mGamma; - const double mEta; - - const MatrixXiR mH; - const std::vector mH_zero_indices; - - - static Eigen::RowVectorXd L_awgn(const Eigen::RowVectorXd& s, - const Eigen::RowVectorXd& y) { - return s.array() - y.array(); - } - - static std::vector find_zero(MatrixXiR mat) { - std::vector indices; - - for (Eigen::Index i = 0; i < mat.size(); ++i) - if (mat(i) == 0) indices.push_back(i); - - return indices; - } - - Eigen::RowVectorXd grad_H(const Eigen::RowVectorXd& x) { - MatrixXdR A_prod_matrix = x.replicate(mH.rows(), 1); - - for (const auto& index : mH_zero_indices) - A_prod_matrix(index) = 1; - MatrixXdR A_prods = A_prod_matrix.rowwise().prod(); - - - Eigen::RowVectorXd B_sums = - (A_prods.array().pow(2) - A_prods.array()).matrix().transpose(); - B_sums = B_sums * mH.cast(); - - Eigen::RowVectorXd result = 4 * (x.array().pow(2) - 1) * x.array() + - (2 * x.array().inverse()) * B_sums.array(); - - return result; - } - - bool check_parity(const Eigen::RowVectorXi& x_hat) { - Eigen::RowVectorXi syndrome = - (mH * x_hat.transpose()).unaryExpr([](int i) { return i % 2; }); - - return !(syndrome.count() > 0); - } - - Eigen::RowVectorXd projection(const Eigen::RowVectorXd& v) { - return v.cwiseMin(mEta).cwiseMax(-mEta); - } -}; - - -PYBIND11_MODULE(cpp_decoders, proximal) { - proximal.doc() = "Proximal decoder"; - - pybind11::class_(proximal, "ProximalDecoder") - .def(pybind11::init(), - "H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002, - "gamma"_a = .05, "eta"_a = 1.5) - .def("decode", &ProximalDecoder::decode, "x"_a.noconvert()) - .def(pybind11::pickle( - [](const ProximalDecoder& a) { // dump - return pybind11::make_tuple(a.mH, a.mK, a.mOmega, a.mGamma, - a.mEta); - }, - [](pybind11::tuple t) { // load - return ProximalDecoder{t[0].cast(), t[1].cast(), - t[2].cast(), t[3].cast(), - t[4].cast()}; - })); - - pybind11::register_exception(proximal, "CppException"); -} \ No newline at end of file diff --git a/sw/cpp/src/proximal.h b/sw/cpp/src/proximal.h new file mode 100644 index 0000000..809c4da --- /dev/null +++ b/sw/cpp/src/proximal.h @@ -0,0 +1,181 @@ +#pragma once + +#define EIGEN_STACK_ALLOCATION_LIMIT 524288 + +#include +#include +#include +#include + +#include +#include + + +/* + * + * Using declarations + * + */ + + +template +using MatrixiR = Eigen::Matrix; + +template +using MatrixdR = Eigen::Matrix; + +template +using RowVectori = Eigen::RowVector; + +template +using RowVectord = Eigen::RowVector; + + +/* + * + * Proximal decoder implementation + * + */ + + +/** + * @brief Class implementing the Proximal Decoding algorithm. See "Proximal + * Decoding for LDPC Codes" by Tadashi Wadayama, and Satoshi Takabe. + * @tparam t_m Number of rows of the H Matrix + * @tparam t_n Number of columns of the H Matrix + */ +template +class ProximalDecoder { +public: + /** + * @brief Constructor + * @param H Parity-Check Matrix + * @param K Number of iterations to run while decoding + * @param omega Step size + * @param gamma Positive constant. Arises in the approximation of the prior + * PDF + * @param eta Positive constant slightly larger than one. See 3.2, p. 3 + */ + ProximalDecoder(const Eigen::Ref>& H, int K, + double omega, double gamma, double eta) + : mK(K), mOmega(omega), mGamma(gamma), mEta(eta), mH(H), + mH_zero_indices(find_zero(H)) { + + Eigen::setNbThreads(8); + } + + /** + * @brief Decode a received signal. The algorithm is detailed in 3.2, p.3. + * This function assumes a BPSK modulated signal and an AWGN channel. + * @param y Vector of received values. (y = x + w, where 'x' is element of + * [-1, 1]^n and 'w' is noise) + * @return Most probably sent codeword (element of [0, 1]^n). If decoding + * fails, the returned value is 'None' + */ + std::pair>, int> + decode(const Eigen::Ref>& y) { + if (y.size() != mH.cols()) + throw std::runtime_error("Length of vector must match H matrix"); + + RowVectord s = RowVectord::Zero(t_n); + RowVectori x_hat; + RowVectord r; + + for (std::size_t i = 0; i < mK; ++i) { + r = s - mOmega * L_awgn(s, y); + + s = projection(r - mGamma * grad_H(r)); + + x_hat = s.unaryExpr([](double d) { + uint64_t bits = std::bit_cast(d); + // Return the sign bit: 1 for negative, 0 for positive + return (bits >> 63); + }).template cast(); + + if (check_parity(x_hat)) { + return {x_hat, i + 1}; + } + } + + return {std::nullopt, mK}; + } + + /// Private members are not private in order to make the class easily + /// picklable + // private: + const int mK; + const double mOmega; + const double mGamma; + const double mEta; + + const MatrixiR mH; + const std::vector mH_zero_indices; + + + /** + * Variation of the negative log-likelihood for the special case of AWGN + * noise. See 4.1, p. 4. + */ + static Eigen::RowVectorXd L_awgn(const RowVectord& s, + const RowVectord& y) { + return s.array() - y.array(); + } + + /** + * @brief Find all indices of a matrix, where the corresponding value is + * zero + * @return \b std::vector of indices + */ + static std::vector find_zero(MatrixiR mat) { + std::vector indices; + + for (Eigen::Index i = 0; i < mat.size(); ++i) + if (mat(i) == 0) indices.push_back(i); + + return indices; + } + + /** + * Gradient of the code-constraint polynomial. See 2.3, p. 2. + */ + RowVectord grad_H(const RowVectord& x) { + MatrixdR A_prod_matrix = x.replicate(t_m, 1); + + for (const auto& index : mH_zero_indices) + A_prod_matrix(index) = 1; + RowVectord A_prods = A_prod_matrix.rowwise().prod(); + + RowVectord B_terms = + (A_prods.array().pow(2) - A_prods.array()).matrix().transpose(); + + RowVectord B_sums = B_terms * mH.template cast(); + + RowVectord result = 4 * (x.array().pow(2) - 1) * x.array() + + (2 * x.array().inverse()) * B_sums.array(); + + return result; + } + + /** + * Perform a parity check for a given codeword. + * @param x_hat: codeword to be checked (element of [0, 1]^n) + * @return \b True if the parity check passes, i.e. the codeword is valid. + * False otherwise + */ + bool check_parity(const RowVectori& x_hat) { + RowVectori syndrome = + (mH * x_hat.transpose()).unaryExpr([](int i) { return i % 2; }); + + return !(syndrome.count() > 0); + } + + /** + * Project a vector onto [-eta, eta]^n in order to avoid numerical + * instability. Detailed in 3.2, p. 3 (Equation (15)). + * @param v Vector to project + * @return v clipped to [-eta, eta]^n + */ + RowVectord projection(const RowVectord& v) { + return v.cwiseMin(mEta).cwiseMax(-mEta); + } +}; diff --git a/sw/cpp/src/python_interface.cpp b/sw/cpp/src/python_interface.cpp new file mode 100644 index 0000000..a9654af --- /dev/null +++ b/sw/cpp/src/python_interface.cpp @@ -0,0 +1,38 @@ +#include "proximal.h" + +#include + + +namespace py = pybind11; +using namespace pybind11::literals; + + +#define DEF_PROXIMAL_DECODER(name, m, n) \ + py::class_>(proximal, name) \ + .def(py::init, int, double, double, double>(), \ + "H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002, \ + "gamma"_a = .05, "eta"_a = 1.5) \ + .def("decode", &ProximalDecoder::decode, "x"_a.noconvert()) \ + .def(py::pickle( \ + [](const ProximalDecoder& a) { \ + return py::make_tuple(a.mH, a.mK, a.mOmega, a.mGamma, a.mEta); \ + }, \ + [](py::tuple t) { \ + return ProximalDecoder{ \ + t[0].cast>(), t[1].cast(), \ + t[2].cast(), t[3].cast(), \ + t[4].cast()}; \ + })); + + +PYBIND11_MODULE(cpp_decoders, proximal) { + proximal.doc() = "Proximal decoder"; + + DEF_PROXIMAL_DECODER("ProximalDecoder_7_4", 4, 7) + DEF_PROXIMAL_DECODER("ProximalDecoder_96_48", 48, 96) + DEF_PROXIMAL_DECODER("ProximalDecoder_204_102", 102, 204) + DEF_PROXIMAL_DECODER("ProximalDecoder_Dynamic", Eigen::Dynamic, + Eigen::Dynamic) + + py::register_exception(proximal, "CppException"); +} \ No newline at end of file