Made cpp ProximalDecoder picklable

This commit is contained in:
Andreas Tsouchlos 2022-11-25 18:47:23 +01:00
parent eb32d83ed0
commit 5662ba841a

View File

@ -25,7 +25,7 @@ class ProximalDecoder {
public:
ProximalDecoder(const Eigen::Ref<const MatrixXiR>& H, int K, double omega,
double gamma, double eta)
: mN(H.cols()), mK(K), mOmega(omega), mGamma(gamma), mEta(eta), mH(H),
: mK(K), mOmega(omega), mGamma(gamma), mEta(eta), mH(H),
mH_zero_indices(find_zero(H)) {
Eigen::setNbThreads(8);
@ -56,8 +56,7 @@ public:
return {std::nullopt, mK};
}
private:
const int mN;
// private:
const int mK;
const double mOmega;
const double mGamma;
@ -119,5 +118,15 @@ PYBIND11_MODULE(cpp_decoders, proximal) {
.def(pybind11::init<MatrixXiR, int, double, double, double>(),
"H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002,
"gamma"_a = .05, "eta"_a = 1.5)
.def("decode", &ProximalDecoder::decode, "x"_a.noconvert());
.def("decode", &ProximalDecoder::decode, "x"_a.noconvert())
.def(pybind11::pickle(
[](const ProximalDecoder& a) { // dump
return pybind11::make_tuple(a.mH, a.mK, a.mOmega, a.mGamma,
a.mEta);
},
[](pybind11::tuple t) { // load
return ProximalDecoder{t[0].cast<MatrixXiR>(), t[1].cast<int>(),
t[2].cast<double>(), t[3].cast<double>(),
t[4].cast<double>()};
}));
}