Implemented cpp ProximalDecoder get_error_values() and get_gradient_values()

This commit is contained in:
Andreas Tsouchlos 2022-12-05 14:53:14 +01:00
parent 9b9eaa6566
commit a32d5cb2c9
2 changed files with 94 additions and 1 deletions

View File

@ -2,6 +2,7 @@
#include <Eigen/Dense>
#include <bit>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
@ -100,6 +101,94 @@ public:
return {x_hat, -1};
}
/**
* @brief Decode a received signal an measure error value (x - sign(x_hat))
* @param x Transmitted word
* @param y Received signal
* @return \b std::vector of error values. Each element corresponds to one
* iteration of the algorithm
*/
std::vector<double>
get_error_values(const Eigen::Ref<const RowVectori<t_n>>& x,
const Eigen::Ref<const RowVectord<t_n>>& y) {
if (y.size() != mH.cols())
throw std::runtime_error("Length of vector must match H matrix");
std::vector<double> error_values;
error_values.reserve(mK);
RowVectord<t_n> s = RowVectord<t_n>::Zero(t_n);
RowVectori<t_n> x_hat;
RowVectord<t_n> r;
for (std::size_t i = 0; i < mK; ++i) {
r = s - mOmega * L_awgn(s, y);
s = projection(r - mGamma * grad_H(r));
x_hat = s.unaryExpr([](double d) {
uint64_t bits = std::bit_cast<uint64_t>(d);
// Return the sign bit: 1 for negative, 0 for positive
return (bits >> 63);
}).template cast<int>();
RowVectord<t_n> x_hat_bpsk =
-1 * ((2 * x_hat.template cast<double>()).array() - 1).matrix();
error_values.push_back(
(x.template cast<double>() - x_hat_bpsk).norm());
if (check_parity(x_hat)) {
break;
}
}
return error_values;
}
/**
* @brief Decode a received signal and measure the norm of the two gradients
* at each iteration
* @param y
* @return \b std::vector of \b std::pair of gradient values. Each element corresponds to
* one iteration. Result is of the form [(grad_H_1, grad_L_1), ...]
*/
std::vector<std::pair<double, double>>
get_gradient_values(const Eigen::Ref<const RowVectord<t_n>>& y) {
if (y.size() != mH.cols())
throw std::runtime_error("Length of vector must match H matrix");
std::vector<std::pair<double, double>> gradient_values;
gradient_values.reserve(mK);
RowVectord<t_n> s = RowVectord<t_n>::Zero(t_n);
RowVectori<t_n> x_hat;
RowVectord<t_n> r;
for (std::size_t i = 0; i < mK; ++i) {
RowVectord<t_n> gradl = L_awgn(s, y);
r = s - mOmega * gradl;
RowVectord<t_n> gradh = grad_H(r);
s = projection(r - mGamma * gradh);
gradient_values.push_back({gradh.norm(), gradl.norm()});
x_hat = s.unaryExpr([](double d) {
uint64_t bits = std::bit_cast<uint64_t>(d);
// Return the sign bit: 1 for negative, 0 for positive
return (bits >> 63);
}).template cast<int>();
if (check_parity(x_hat)) {
break;
}
}
return gradient_values;
}
/**
* @brief Get the values of all member variables necessary to recreate an
* exact copy of this class. Used for pickling

View File

@ -14,7 +14,11 @@ using namespace pybind11::literals;
.def(py::init<MatrixiR<m, n>, int, double, double, double>(), \
"H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002, \
"gamma"_a = .05, "eta"_a = 1.5) \
.def("decode", &ProximalDecoder<m, n>::decode, "x"_a.noconvert()) \
.def("decode", &ProximalDecoder<m, n>::decode, "y"_a.noconvert()) \
.def("get_error_values", &ProximalDecoder<m, n>::get_error_values, \
"x"_a.noconvert(), "y"_a.noconvert()) \
.def("get_gradient_values", \
&ProximalDecoder<m, n>::get_gradient_values, "y"_a.noconvert()) \
.def(py::pickle( \
[](const ProximalDecoder<m, n>& a) { \
auto state = a.get_decoder_state(); \