Implemented cpp ProximalDecoder get_error_values() and get_gradient_values()
This commit is contained in:
parent
9b9eaa6566
commit
a32d5cb2c9
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include <Eigen/Dense>
|
#include <Eigen/Dense>
|
||||||
#include <bit>
|
#include <bit>
|
||||||
|
#include <cstdlib>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
@ -100,6 +101,94 @@ public:
|
|||||||
return {x_hat, -1};
|
return {x_hat, -1};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Decode a received signal an measure error value (x - sign(x_hat))
|
||||||
|
* @param x Transmitted word
|
||||||
|
* @param y Received signal
|
||||||
|
* @return \b std::vector of error values. Each element corresponds to one
|
||||||
|
* iteration of the algorithm
|
||||||
|
*/
|
||||||
|
std::vector<double>
|
||||||
|
get_error_values(const Eigen::Ref<const RowVectori<t_n>>& x,
|
||||||
|
const Eigen::Ref<const RowVectord<t_n>>& y) {
|
||||||
|
|
||||||
|
if (y.size() != mH.cols())
|
||||||
|
throw std::runtime_error("Length of vector must match H matrix");
|
||||||
|
|
||||||
|
std::vector<double> error_values;
|
||||||
|
error_values.reserve(mK);
|
||||||
|
|
||||||
|
RowVectord<t_n> s = RowVectord<t_n>::Zero(t_n);
|
||||||
|
RowVectori<t_n> x_hat;
|
||||||
|
RowVectord<t_n> r;
|
||||||
|
|
||||||
|
for (std::size_t i = 0; i < mK; ++i) {
|
||||||
|
r = s - mOmega * L_awgn(s, y);
|
||||||
|
|
||||||
|
s = projection(r - mGamma * grad_H(r));
|
||||||
|
|
||||||
|
x_hat = s.unaryExpr([](double d) {
|
||||||
|
uint64_t bits = std::bit_cast<uint64_t>(d);
|
||||||
|
// Return the sign bit: 1 for negative, 0 for positive
|
||||||
|
return (bits >> 63);
|
||||||
|
}).template cast<int>();
|
||||||
|
|
||||||
|
RowVectord<t_n> x_hat_bpsk =
|
||||||
|
-1 * ((2 * x_hat.template cast<double>()).array() - 1).matrix();
|
||||||
|
error_values.push_back(
|
||||||
|
(x.template cast<double>() - x_hat_bpsk).norm());
|
||||||
|
|
||||||
|
if (check_parity(x_hat)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return error_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Decode a received signal and measure the norm of the two gradients
|
||||||
|
* at each iteration
|
||||||
|
* @param y
|
||||||
|
* @return \b std::vector of \b std::pair of gradient values. Each element corresponds to
|
||||||
|
* one iteration. Result is of the form [(grad_H_1, grad_L_1), ...]
|
||||||
|
*/
|
||||||
|
std::vector<std::pair<double, double>>
|
||||||
|
get_gradient_values(const Eigen::Ref<const RowVectord<t_n>>& y) {
|
||||||
|
|
||||||
|
if (y.size() != mH.cols())
|
||||||
|
throw std::runtime_error("Length of vector must match H matrix");
|
||||||
|
|
||||||
|
std::vector<std::pair<double, double>> gradient_values;
|
||||||
|
gradient_values.reserve(mK);
|
||||||
|
|
||||||
|
RowVectord<t_n> s = RowVectord<t_n>::Zero(t_n);
|
||||||
|
RowVectori<t_n> x_hat;
|
||||||
|
RowVectord<t_n> r;
|
||||||
|
|
||||||
|
for (std::size_t i = 0; i < mK; ++i) {
|
||||||
|
RowVectord<t_n> gradl = L_awgn(s, y);
|
||||||
|
r = s - mOmega * gradl;
|
||||||
|
|
||||||
|
RowVectord<t_n> gradh = grad_H(r);
|
||||||
|
s = projection(r - mGamma * gradh);
|
||||||
|
|
||||||
|
gradient_values.push_back({gradh.norm(), gradl.norm()});
|
||||||
|
|
||||||
|
x_hat = s.unaryExpr([](double d) {
|
||||||
|
uint64_t bits = std::bit_cast<uint64_t>(d);
|
||||||
|
// Return the sign bit: 1 for negative, 0 for positive
|
||||||
|
return (bits >> 63);
|
||||||
|
}).template cast<int>();
|
||||||
|
|
||||||
|
if (check_parity(x_hat)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return gradient_values;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get the values of all member variables necessary to recreate an
|
* @brief Get the values of all member variables necessary to recreate an
|
||||||
* exact copy of this class. Used for pickling
|
* exact copy of this class. Used for pickling
|
||||||
|
|||||||
@ -14,7 +14,11 @@ using namespace pybind11::literals;
|
|||||||
.def(py::init<MatrixiR<m, n>, int, double, double, double>(), \
|
.def(py::init<MatrixiR<m, n>, int, double, double, double>(), \
|
||||||
"H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002, \
|
"H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002, \
|
||||||
"gamma"_a = .05, "eta"_a = 1.5) \
|
"gamma"_a = .05, "eta"_a = 1.5) \
|
||||||
.def("decode", &ProximalDecoder<m, n>::decode, "x"_a.noconvert()) \
|
.def("decode", &ProximalDecoder<m, n>::decode, "y"_a.noconvert()) \
|
||||||
|
.def("get_error_values", &ProximalDecoder<m, n>::get_error_values, \
|
||||||
|
"x"_a.noconvert(), "y"_a.noconvert()) \
|
||||||
|
.def("get_gradient_values", \
|
||||||
|
&ProximalDecoder<m, n>::get_gradient_values, "y"_a.noconvert()) \
|
||||||
.def(py::pickle( \
|
.def(py::pickle( \
|
||||||
[](const ProximalDecoder<m, n>& a) { \
|
[](const ProximalDecoder<m, n>& a) { \
|
||||||
auto state = a.get_decoder_state(); \
|
auto state = a.get_decoder_state(); \
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user