Added check for vector dimensions to cpp proximal decoder
This commit is contained in:
parent
b33a0735f0
commit
31f137067c
@ -24,7 +24,7 @@ include_directories(${pybind11_INCLUDE_DIRS})
|
|||||||
|
|
||||||
find_package(OpenMP REQUIRED)
|
find_package(OpenMP REQUIRED)
|
||||||
|
|
||||||
add_compile_options(-ffast-math)
|
#add_compile_options(-ffast-math)
|
||||||
|
|
||||||
pybind11_add_module(cpp_decoders src/cpp_decoders.cpp)
|
pybind11_add_module(cpp_decoders src/cpp_decoders.cpp)
|
||||||
target_link_libraries(cpp_decoders PRIVATE Eigen3::Eigen OpenMP::OpenMP_CXX)
|
target_link_libraries(cpp_decoders PRIVATE Eigen3::Eigen OpenMP::OpenMP_CXX)
|
||||||
|
|||||||
@ -1,14 +1,13 @@
|
|||||||
#include <Eigen/Dense>
|
#include <Eigen/Dense>
|
||||||
|
#include <bit>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include <pybind11/eigen.h>
|
||||||
#include <pybind11/numpy.h>
|
#include <pybind11/numpy.h>
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
#include <pybind11/eigen.h>
|
|
||||||
|
|
||||||
#include <bit>
|
|
||||||
|
|
||||||
|
|
||||||
namespace py11 = pybind11;
|
namespace py11 = pybind11;
|
||||||
using namespace pybind11::literals;
|
using namespace pybind11::literals;
|
||||||
@ -33,6 +32,9 @@ public:
|
|||||||
|
|
||||||
std::pair<std::optional<Eigen::RowVectorXi>, int>
|
std::pair<std::optional<Eigen::RowVectorXi>, int>
|
||||||
decode(const Eigen::Ref<const Eigen::VectorXd>& y) {
|
decode(const Eigen::Ref<const Eigen::VectorXd>& y) {
|
||||||
|
if (y.size() != mH.cols())
|
||||||
|
throw std::runtime_error("Length of vector must match H matrix");
|
||||||
|
|
||||||
Eigen::RowVectorXd s = Eigen::RowVectorXd::Zero(mH.cols());
|
Eigen::RowVectorXd s = Eigen::RowVectorXd::Zero(mH.cols());
|
||||||
Eigen::RowVectorXi x_hat;
|
Eigen::RowVectorXi x_hat;
|
||||||
Eigen::RowVectorXd r;
|
Eigen::RowVectorXd r;
|
||||||
@ -129,4 +131,6 @@ PYBIND11_MODULE(cpp_decoders, proximal) {
|
|||||||
t[2].cast<double>(), t[3].cast<double>(),
|
t[2].cast<double>(), t[3].cast<double>(),
|
||||||
t[4].cast<double>()};
|
t[4].cast<double>()};
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
pybind11::register_exception<std::runtime_error>(proximal, "CppException");
|
||||||
}
|
}
|
||||||
Loading…
Reference in New Issue
Block a user