Implemented cpp ProximalDecoder get_error_values() and get_gradient_values()
This commit is contained in:
parent
9b9eaa6566
commit
a32d5cb2c9
@ -2,6 +2,7 @@
|
||||
|
||||
#include <Eigen/Dense>
|
||||
#include <bit>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
@ -100,6 +101,94 @@ public:
|
||||
return {x_hat, -1};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Decode a received signal an measure error value (x - sign(x_hat))
|
||||
* @param x Transmitted word
|
||||
* @param y Received signal
|
||||
* @return \b std::vector of error values. Each element corresponds to one
|
||||
* iteration of the algorithm
|
||||
*/
|
||||
std::vector<double>
|
||||
get_error_values(const Eigen::Ref<const RowVectori<t_n>>& x,
|
||||
const Eigen::Ref<const RowVectord<t_n>>& y) {
|
||||
|
||||
if (y.size() != mH.cols())
|
||||
throw std::runtime_error("Length of vector must match H matrix");
|
||||
|
||||
std::vector<double> error_values;
|
||||
error_values.reserve(mK);
|
||||
|
||||
RowVectord<t_n> s = RowVectord<t_n>::Zero(t_n);
|
||||
RowVectori<t_n> x_hat;
|
||||
RowVectord<t_n> r;
|
||||
|
||||
for (std::size_t i = 0; i < mK; ++i) {
|
||||
r = s - mOmega * L_awgn(s, y);
|
||||
|
||||
s = projection(r - mGamma * grad_H(r));
|
||||
|
||||
x_hat = s.unaryExpr([](double d) {
|
||||
uint64_t bits = std::bit_cast<uint64_t>(d);
|
||||
// Return the sign bit: 1 for negative, 0 for positive
|
||||
return (bits >> 63);
|
||||
}).template cast<int>();
|
||||
|
||||
RowVectord<t_n> x_hat_bpsk =
|
||||
-1 * ((2 * x_hat.template cast<double>()).array() - 1).matrix();
|
||||
error_values.push_back(
|
||||
(x.template cast<double>() - x_hat_bpsk).norm());
|
||||
|
||||
if (check_parity(x_hat)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return error_values;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Decode a received signal and measure the norm of the two gradients
|
||||
* at each iteration
|
||||
* @param y
|
||||
* @return \b std::vector of \b std::pair of gradient values. Each element corresponds to
|
||||
* one iteration. Result is of the form [(grad_H_1, grad_L_1), ...]
|
||||
*/
|
||||
std::vector<std::pair<double, double>>
|
||||
get_gradient_values(const Eigen::Ref<const RowVectord<t_n>>& y) {
|
||||
|
||||
if (y.size() != mH.cols())
|
||||
throw std::runtime_error("Length of vector must match H matrix");
|
||||
|
||||
std::vector<std::pair<double, double>> gradient_values;
|
||||
gradient_values.reserve(mK);
|
||||
|
||||
RowVectord<t_n> s = RowVectord<t_n>::Zero(t_n);
|
||||
RowVectori<t_n> x_hat;
|
||||
RowVectord<t_n> r;
|
||||
|
||||
for (std::size_t i = 0; i < mK; ++i) {
|
||||
RowVectord<t_n> gradl = L_awgn(s, y);
|
||||
r = s - mOmega * gradl;
|
||||
|
||||
RowVectord<t_n> gradh = grad_H(r);
|
||||
s = projection(r - mGamma * gradh);
|
||||
|
||||
gradient_values.push_back({gradh.norm(), gradl.norm()});
|
||||
|
||||
x_hat = s.unaryExpr([](double d) {
|
||||
uint64_t bits = std::bit_cast<uint64_t>(d);
|
||||
// Return the sign bit: 1 for negative, 0 for positive
|
||||
return (bits >> 63);
|
||||
}).template cast<int>();
|
||||
|
||||
if (check_parity(x_hat)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return gradient_values;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the values of all member variables necessary to recreate an
|
||||
* exact copy of this class. Used for pickling
|
||||
|
||||
@ -14,7 +14,11 @@ using namespace pybind11::literals;
|
||||
.def(py::init<MatrixiR<m, n>, int, double, double, double>(), \
|
||||
"H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002, \
|
||||
"gamma"_a = .05, "eta"_a = 1.5) \
|
||||
.def("decode", &ProximalDecoder<m, n>::decode, "x"_a.noconvert()) \
|
||||
.def("decode", &ProximalDecoder<m, n>::decode, "y"_a.noconvert()) \
|
||||
.def("get_error_values", &ProximalDecoder<m, n>::get_error_values, \
|
||||
"x"_a.noconvert(), "y"_a.noconvert()) \
|
||||
.def("get_gradient_values", \
|
||||
&ProximalDecoder<m, n>::get_gradient_values, "y"_a.noconvert()) \
|
||||
.def(py::pickle( \
|
||||
[](const ProximalDecoder<m, n>& a) { \
|
||||
auto state = a.get_decoder_state(); \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user