diff --git a/sw/cpp/src/proximal.h b/sw/cpp/src/proximal.h index b582285..42e47b9 100644 --- a/sw/cpp/src/proximal.h +++ b/sw/cpp/src/proximal.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -100,6 +101,94 @@ public: return {x_hat, -1}; } + /** + * @brief Decode a received signal an measure error value (x - sign(x_hat)) + * @param x Transmitted word + * @param y Received signal + * @return \b std::vector of error values. Each element corresponds to one + * iteration of the algorithm + */ + std::vector + get_error_values(const Eigen::Ref>& x, + const Eigen::Ref>& y) { + + if (y.size() != mH.cols()) + throw std::runtime_error("Length of vector must match H matrix"); + + std::vector error_values; + error_values.reserve(mK); + + RowVectord s = RowVectord::Zero(t_n); + RowVectori x_hat; + RowVectord r; + + for (std::size_t i = 0; i < mK; ++i) { + r = s - mOmega * L_awgn(s, y); + + s = projection(r - mGamma * grad_H(r)); + + x_hat = s.unaryExpr([](double d) { + uint64_t bits = std::bit_cast(d); + // Return the sign bit: 1 for negative, 0 for positive + return (bits >> 63); + }).template cast(); + + RowVectord x_hat_bpsk = + -1 * ((2 * x_hat.template cast()).array() - 1).matrix(); + error_values.push_back( + (x.template cast() - x_hat_bpsk).norm()); + + if (check_parity(x_hat)) { + break; + } + } + + return error_values; + } + + /** + * @brief Decode a received signal and measure the norm of the two gradients + * at each iteration + * @param y + * @return \b std::vector of \b std::pair of gradient values. Each element corresponds to + * one iteration. Result is of the form [(grad_H_1, grad_L_1), ...] + */ + std::vector> + get_gradient_values(const Eigen::Ref>& y) { + + if (y.size() != mH.cols()) + throw std::runtime_error("Length of vector must match H matrix"); + + std::vector> gradient_values; + gradient_values.reserve(mK); + + RowVectord s = RowVectord::Zero(t_n); + RowVectori x_hat; + RowVectord r; + + for (std::size_t i = 0; i < mK; ++i) { + RowVectord gradl = L_awgn(s, y); + r = s - mOmega * gradl; + + RowVectord gradh = grad_H(r); + s = projection(r - mGamma * gradh); + + gradient_values.push_back({gradh.norm(), gradl.norm()}); + + x_hat = s.unaryExpr([](double d) { + uint64_t bits = std::bit_cast(d); + // Return the sign bit: 1 for negative, 0 for positive + return (bits >> 63); + }).template cast(); + + if (check_parity(x_hat)) { + break; + } + } + + return gradient_values; + } + /** * @brief Get the values of all member variables necessary to recreate an * exact copy of this class. Used for pickling diff --git a/sw/cpp/src/python_interface.cpp b/sw/cpp/src/python_interface.cpp index c650258..4970e75 100644 --- a/sw/cpp/src/python_interface.cpp +++ b/sw/cpp/src/python_interface.cpp @@ -14,7 +14,11 @@ using namespace pybind11::literals; .def(py::init, int, double, double, double>(), \ "H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002, \ "gamma"_a = .05, "eta"_a = 1.5) \ - .def("decode", &ProximalDecoder::decode, "x"_a.noconvert()) \ + .def("decode", &ProximalDecoder::decode, "y"_a.noconvert()) \ + .def("get_error_values", &ProximalDecoder::get_error_values, \ + "x"_a.noconvert(), "y"_a.noconvert()) \ + .def("get_gradient_values", \ + &ProximalDecoder::get_gradient_values, "y"_a.noconvert()) \ .def(py::pickle( \ [](const ProximalDecoder& a) { \ auto state = a.get_decoder_state(); \