From 5662ba841af173373fe9e4627c87e9b4b358a4b9 Mon Sep 17 00:00:00 2001 From: Andreas Tsouchlos Date: Fri, 25 Nov 2022 18:47:23 +0100 Subject: [PATCH] Made cpp ProximalDecoder picklable --- sw/cpp/src/cpp_decoders.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/sw/cpp/src/cpp_decoders.cpp b/sw/cpp/src/cpp_decoders.cpp index 4d3259d..a00422f 100644 --- a/sw/cpp/src/cpp_decoders.cpp +++ b/sw/cpp/src/cpp_decoders.cpp @@ -25,7 +25,7 @@ class ProximalDecoder { public: ProximalDecoder(const Eigen::Ref& H, int K, double omega, double gamma, double eta) - : mN(H.cols()), mK(K), mOmega(omega), mGamma(gamma), mEta(eta), mH(H), + : mK(K), mOmega(omega), mGamma(gamma), mEta(eta), mH(H), mH_zero_indices(find_zero(H)) { Eigen::setNbThreads(8); @@ -56,8 +56,7 @@ public: return {std::nullopt, mK}; } -private: - const int mN; + // private: const int mK; const double mOmega; const double mGamma; @@ -119,5 +118,15 @@ PYBIND11_MODULE(cpp_decoders, proximal) { .def(pybind11::init(), "H"_a.noconvert(), "K"_a = 100, "omega"_a = 0.0002, "gamma"_a = .05, "eta"_a = 1.5) - .def("decode", &ProximalDecoder::decode, "x"_a.noconvert()); + .def("decode", &ProximalDecoder::decode, "x"_a.noconvert()) + .def(pybind11::pickle( + [](const ProximalDecoder& a) { // dump + return pybind11::make_tuple(a.mH, a.mK, a.mOmega, a.mGamma, + a.mEta); + }, + [](pybind11::tuple t) { // load + return ProximalDecoder{t[0].cast(), t[1].cast(), + t[2].cast(), t[3].cast(), + t[4].cast()}; + })); } \ No newline at end of file