diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..67fe8eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +build/ +.cache/ +__pycache__/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e1554d0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,4 @@ +FROM alpine:3.21.3 + +RUN apk update +RUN apk add cmake make g++ python3 eigen-dev spdlog-dev diff --git a/LICENSE b/LICENSE index 9327b75..5e00f8c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2025 an.tsouchlos +Copyright (c) 2025 Andreas Tsouchlos Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/README.md b/README.md index 56d774f..e6d5fbe 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,42 @@ -# homotopy-continuation-channel-decoding +# Homotopy Continuation Channel Decoding -A project attempting to use homotopy continuation to perform channel decoding. +A project using homotopy continuation methods to perform channel decoding. + +This repository contains implementations in multiple programming languages, +that are independent of each other. E.g., for each programming language there +are dedicated examples. + +| Directory | Content | +| ---------- | ------------------ | +| `docs/` | Documentation | +| `scripts/` | Utility scripts | +| `cpp/` | C++ source code | +| `python/` | Python source code | + +## Use Python Code + +```bash +$ python python/examples/toy_homotopy.py -o temp.csv +$ python scripts/plot_solution_curve.py -i temp.csv +``` + +## Use C++ Code + +#### On Host System + +```bash +$ cmake -B build -S cpp +$ cmake --build build +$ ./build/toy_example_cpp/toy_example -o temp.csv +$ python scripts/plot_solution_curve.py -i temp.csv +``` + +#### In Docker Container + +```bash +$ docker build . -t hccd +$ docker run --rm -it -u `id -u`:`id -g` -v $PWD:$PWD -w $PWD hccd cmake -B build -S cpp +$ docker run --rm -it -u `id -u`:`id -g` -v $PWD:$PWD -w $PWD hccd cmake --build build +$ docker run --rm -it -u `id -u`:`id -g` -v $PWD:$PWD -w $PWD hccd ./build/toy_example_cpp/toy_example -o temp.csv +$ python scripts/plot_solution_curve.py -i temp.csv +``` diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt new file mode 100644 index 0000000..94cd427 --- /dev/null +++ b/cpp/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.22) + +project(HomotopyContinuation) + +set(CMAKE_CXX_STANDARD 23) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +find_package(spdlog REQUIRED) +find_package(Eigen3 REQUIRED) + +add_library(hccd INTERFACE) +target_include_directories(hccd INTERFACE include) + +add_subdirectory(examples) diff --git a/cpp/examples/CMakeLists.txt b/cpp/examples/CMakeLists.txt new file mode 100644 index 0000000..2724068 --- /dev/null +++ b/cpp/examples/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy_homotopy) diff --git a/cpp/examples/toy_homotopy/CMakeLists.txt b/cpp/examples/toy_homotopy/CMakeLists.txt new file mode 100644 index 0000000..a50aba9 --- /dev/null +++ b/cpp/examples/toy_homotopy/CMakeLists.txt @@ -0,0 +1,7 @@ +set(CMAKE_BuilD_TYPE Release) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +add_executable(toy_homotopy toy_homotopy.cpp) +target_link_libraries(toy_homotopy PRIVATE hccd spdlog::spdlog Eigen3::Eigen) +target_include_directories(toy_homotopy PRIVATE hccd lib/argparse/include) diff --git a/cpp/examples/toy_homotopy/argparse.hpp b/cpp/examples/toy_homotopy/argparse.hpp new file mode 100644 index 0000000..aaff069 --- /dev/null +++ b/cpp/examples/toy_homotopy/argparse.hpp @@ -0,0 +1,2739 @@ +/* + __ _ _ __ __ _ _ __ __ _ _ __ ___ ___ + / _` | '__/ _` | '_ \ / _` | '__/ __|/ _ \ Argument Parser for Modern C++ +| (_| | | | (_| | |_) | (_| | | \__ \ __/ http://github.com/p-ranav/argparse + \__,_|_| \__, | .__/ \__,_|_| |___/\___| + |___/|_| + +Licensed under the MIT License . +SPDX-License-Identifier: MIT +Copyright (c) 2019-2022 Pranav Srinivas Kumar +and other contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ +#pragma once + +#include + +#ifndef ARGPARSE_MODULE_USE_STD_MODULE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#ifndef ARGPARSE_CUSTOM_STRTOF +#define ARGPARSE_CUSTOM_STRTOF strtof +#endif + +#ifndef ARGPARSE_CUSTOM_STRTOD +#define ARGPARSE_CUSTOM_STRTOD strtod +#endif + +#ifndef ARGPARSE_CUSTOM_STRTOLD +#define ARGPARSE_CUSTOM_STRTOLD strtold +#endif + +namespace argparse { + +namespace details { // namespace for helper methods + +template +struct HasContainerTraits : std::false_type {}; + +template <> +struct HasContainerTraits : std::false_type {}; + +template <> +struct HasContainerTraits : std::false_type {}; + +template +struct HasContainerTraits< + T, std::void_t().begin()), + decltype(std::declval().end()), + decltype(std::declval().size())>> : std::true_type {}; + +template +inline constexpr bool IsContainer = HasContainerTraits::value; + +template +struct HasStreamableTraits : std::false_type {}; + +template +struct HasStreamableTraits() + << std::declval())>> + : std::true_type {}; + +template +inline constexpr bool IsStreamable = HasStreamableTraits::value; + +constexpr std::size_t repr_max_container_size = 5; + +template +std::string repr(T const& val) { + if constexpr (std::is_same_v) { + return val ? "true" : "false"; + } else if constexpr (std::is_convertible_v) { + return '"' + std::string{std::string_view{val}} + '"'; + } else if constexpr (IsContainer) { + std::stringstream out; + out << "{"; + const auto size = val.size(); + if (size > 1) { + out << repr(*val.begin()); + std::for_each( + std::next(val.begin()), + std::next( + val.begin(), + static_cast( + std::min(size, repr_max_container_size) - + 1)), + [&out](const auto& v) { + out << " " << repr(v); + }); + if (size <= repr_max_container_size) { + out << " "; + } else { + out << "..."; + } + } + if (size > 0) { + out << repr(*std::prev(val.end())); + } + out << "}"; + return out.str(); + } else if constexpr (IsStreamable) { + std::stringstream out; + out << val; + return out.str(); + } else { + return ""; + } +} + +namespace { + +template +constexpr bool standard_signed_integer = false; +template <> +constexpr bool standard_signed_integer = true; +template <> +constexpr bool standard_signed_integer = true; +template <> +constexpr bool standard_signed_integer = true; +template <> +constexpr bool standard_signed_integer = true; +template <> +constexpr bool standard_signed_integer = true; + +template +constexpr bool standard_unsigned_integer = false; +template <> +constexpr bool standard_unsigned_integer = true; +template <> +constexpr bool standard_unsigned_integer = true; +template <> +constexpr bool standard_unsigned_integer = true; +template <> +constexpr bool standard_unsigned_integer = true; +template <> +constexpr bool standard_unsigned_integer = true; + +} // namespace + +constexpr int radix_2 = 2; +constexpr int radix_8 = 8; +constexpr int radix_10 = 10; +constexpr int radix_16 = 16; + +template +constexpr bool standard_integer = + standard_signed_integer || standard_unsigned_integer; + +template +constexpr decltype(auto) +apply_plus_one_impl(F&& f, Tuple&& t, Extra&& x, + std::index_sequence /*unused*/) { + return std::invoke(std::forward(f), + std::get(std::forward(t))..., + std::forward(x)); +} + +template +constexpr decltype(auto) apply_plus_one(F&& f, Tuple&& t, Extra&& x) { + return details::apply_plus_one_impl( + std::forward(f), std::forward(t), std::forward(x), + std::make_index_sequence< + std::tuple_size_v>>{}); +} + +constexpr auto pointer_range(std::string_view s) noexcept { + return std::tuple(s.data(), s.data() + s.size()); +} + +template +constexpr bool starts_with(std::basic_string_view prefix, + std::basic_string_view s) noexcept { + return s.substr(0, prefix.size()) == prefix; +} + +enum class chars_format { + scientific = 0xf1, + fixed = 0xf2, + hex = 0xf4, + binary = 0xf8, + general = fixed | scientific +}; + +struct ConsumeBinaryPrefixResult { + bool is_binary; + std::string_view rest; +}; + +constexpr auto consume_binary_prefix(std::string_view s) + -> ConsumeBinaryPrefixResult { + if (starts_with(std::string_view{"0b"}, s) || + starts_with(std::string_view{"0B"}, s)) { + s.remove_prefix(2); + return {true, s}; + } + return {false, s}; +} + +struct ConsumeHexPrefixResult { + bool is_hexadecimal; + std::string_view rest; +}; + +using namespace std::literals; + +constexpr auto consume_hex_prefix(std::string_view s) + -> ConsumeHexPrefixResult { + if (starts_with("0x"sv, s) || starts_with("0X"sv, s)) { + s.remove_prefix(2); + return {true, s}; + } + return {false, s}; +} + +template +inline auto do_from_chars(std::string_view s) -> T { + T x{0}; + auto [first, last] = pointer_range(s); + auto [ptr, ec] = std::from_chars(first, last, x, Param); + if (ec == std::errc()) { + if (ptr == last) { + return x; + } + throw std::invalid_argument{"pattern '" + std::string(s) + + "' does not match to the end"}; + } + if (ec == std::errc::invalid_argument) { + throw std::invalid_argument{"pattern '" + std::string(s) + + "' not found"}; + } + if (ec == std::errc::result_out_of_range) { + throw std::range_error{"'" + std::string(s) + "' not representable"}; + } + return x; // unreachable +} + +template +struct parse_number { + auto operator()(std::string_view s) -> T { + return do_from_chars(s); + } +}; + +template +struct parse_number { + auto operator()(std::string_view s) -> T { + if (auto [ok, rest] = consume_binary_prefix(s); ok) { + return do_from_chars(rest); + } + throw std::invalid_argument{"pattern not found"}; + } +}; + +template +struct parse_number { + auto operator()(std::string_view s) -> T { + if (starts_with("0x"sv, s) || starts_with("0X"sv, s)) { + if (auto [ok, rest] = consume_hex_prefix(s); ok) { + try { + return do_from_chars(rest); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument( + "Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + + std::string(s) + + "' as hexadecimal: " + err.what()); + } + } + } else { + // Allow passing hex numbers without prefix + // Shape 'x' already has to be specified + try { + return do_from_chars(s); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument("Failed to parse '" + + std::string(s) + + "' as hexadecimal: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } + } + + throw std::invalid_argument{"pattern '" + std::string(s) + + "' not identified as hexadecimal"}; + } +}; + +template +struct parse_number { + auto operator()(std::string_view s) -> T { + auto [ok, rest] = consume_hex_prefix(s); + if (ok) { + try { + return do_from_chars(rest); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument("Failed to parse '" + + std::string(s) + + "' as hexadecimal: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } + } + + auto [ok_binary, rest_binary] = consume_binary_prefix(s); + if (ok_binary) { + try { + return do_from_chars(rest_binary); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument("Failed to parse '" + + std::string(s) + + "' as binary: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as binary: " + err.what()); + } + } + + if (starts_with("0"sv, s)) { + try { + return do_from_chars(rest); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument("Failed to parse '" + + std::string(s) + + "' as octal: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as octal: " + err.what()); + } + } + + try { + return do_from_chars(rest); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument("Failed to parse '" + std::string(s) + + "' as decimal integer: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as decimal integer: " + err.what()); + } + } +}; + +namespace { + +template +inline const auto generic_strtod = nullptr; +template <> +inline const auto generic_strtod = ARGPARSE_CUSTOM_STRTOF; +template <> +inline const auto generic_strtod = ARGPARSE_CUSTOM_STRTOD; +template <> +inline const auto generic_strtod = ARGPARSE_CUSTOM_STRTOLD; + +} // namespace + +template +inline auto do_strtod(std::string const& s) -> T { + if (isspace(static_cast(s[0])) || s[0] == '+') { + throw std::invalid_argument{"pattern '" + s + "' not found"}; + } + + auto [first, last] = pointer_range(s); + char* ptr; + + errno = 0; + auto x = generic_strtod(first, &ptr); + if (errno == 0) { + if (ptr == last) { + return x; + } + throw std::invalid_argument{"pattern '" + s + + "' does not match to the end"}; + } + if (errno == ERANGE) { + throw std::range_error{"'" + s + "' not representable"}; + } + return x; // unreachable +} + +template +struct parse_number { + auto operator()(std::string const& s) -> T { + if (auto r = consume_hex_prefix(s); r.is_hexadecimal) { + throw std::invalid_argument{ + "chars_format::general does not parse hexfloat"}; + } + if (auto r = consume_binary_prefix(s); r.is_binary) { + throw std::invalid_argument{ + "chars_format::general does not parse binfloat"}; + } + + try { + return do_strtod(s); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument("Failed to parse '" + s + + "' as number: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + s + + "' as number: " + err.what()); + } + } +}; + +template +struct parse_number { + auto operator()(std::string const& s) -> T { + if (auto r = consume_hex_prefix(s); !r.is_hexadecimal) { + throw std::invalid_argument{"chars_format::hex parses hexfloat"}; + } + if (auto r = consume_binary_prefix(s); r.is_binary) { + throw std::invalid_argument{ + "chars_format::hex does not parse binfloat"}; + } + + try { + return do_strtod(s); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument("Failed to parse '" + s + + "' as hexadecimal: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + s + + "' as hexadecimal: " + err.what()); + } + } +}; + +template +struct parse_number { + auto operator()(std::string const& s) -> T { + if (auto r = consume_hex_prefix(s); r.is_hexadecimal) { + throw std::invalid_argument{ + "chars_format::binary does not parse hexfloat"}; + } + if (auto r = consume_binary_prefix(s); !r.is_binary) { + throw std::invalid_argument{"chars_format::binary parses binfloat"}; + } + + return do_strtod(s); + } +}; + +template +struct parse_number { + auto operator()(std::string const& s) -> T { + if (auto r = consume_hex_prefix(s); r.is_hexadecimal) { + throw std::invalid_argument{ + "chars_format::scientific does not parse hexfloat"}; + } + if (auto r = consume_binary_prefix(s); r.is_binary) { + throw std::invalid_argument{ + "chars_format::scientific does not parse binfloat"}; + } + if (s.find_first_of("eE") == std::string::npos) { + throw std::invalid_argument{ + "chars_format::scientific requires exponent part"}; + } + + try { + return do_strtod(s); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument( + "Failed to parse '" + s + + "' as scientific notation: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + s + + "' as scientific notation: " + err.what()); + } + } +}; + +template +struct parse_number { + auto operator()(std::string const& s) -> T { + if (auto r = consume_hex_prefix(s); r.is_hexadecimal) { + throw std::invalid_argument{ + "chars_format::fixed does not parse hexfloat"}; + } + if (auto r = consume_binary_prefix(s); r.is_binary) { + throw std::invalid_argument{ + "chars_format::fixed does not parse binfloat"}; + } + if (s.find_first_of("eE") != std::string::npos) { + throw std::invalid_argument{ + "chars_format::fixed does not parse exponent part"}; + } + + try { + return do_strtod(s); + } catch (const std::invalid_argument& err) { + throw std::invalid_argument("Failed to parse '" + s + + "' as fixed notation: " + err.what()); + } catch (const std::range_error& err) { + throw std::range_error("Failed to parse '" + s + + "' as fixed notation: " + err.what()); + } + } +}; + +template +std::string join(StrIt first, StrIt last, const std::string& separator) { + if (first == last) { + return ""; + } + std::stringstream value; + value << *first; + ++first; + while (first != last) { + value << separator << *first; + ++first; + } + return value.str(); +} + +template +struct can_invoke_to_string { + template + static auto test(int) + -> decltype(std::to_string(std::declval()), std::true_type{}); + + template + static auto test(...) -> std::false_type; + + static constexpr bool value = decltype(test(0))::value; +}; + +template +struct IsChoiceTypeSupported { + using CleanType = typename std::decay::type; + static const bool value = + std::is_integral::value || + std::is_same::value || + std::is_same::value || + std::is_same::value; +}; + +template +std::size_t get_levenshtein_distance(const StringType& s1, + const StringType& s2) { + std::vector> dp( + s1.size() + 1, std::vector(s2.size() + 1, 0)); + + for (std::size_t i = 0; i <= s1.size(); ++i) { + for (std::size_t j = 0; j <= s2.size(); ++j) { + if (i == 0) { + dp[i][j] = j; + } else if (j == 0) { + dp[i][j] = i; + } else if (s1[i - 1] == s2[j - 1]) { + dp[i][j] = dp[i - 1][j - 1]; + } else { + dp[i][j] = + 1 + std::min( + {dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}); + } + } + } + + return dp[s1.size()][s2.size()]; +} + +template +std::string get_most_similar_string(const std::map& map, + const std::string& input) { + std::string most_similar{}; + std::size_t min_distance = (std::numeric_limits::max)(); + + for (const auto& entry : map) { + std::size_t distance = get_levenshtein_distance(entry.first, input); + if (distance < min_distance) { + min_distance = distance; + most_similar = entry.first; + } + } + + return most_similar; +} + +} // namespace details + +enum class nargs_pattern { optional, any, at_least_one }; + +enum class default_arguments : unsigned int { + none = 0, + help = 1, + version = 2, + all = help | version, +}; + +inline default_arguments operator&(const default_arguments& a, + const default_arguments& b) { + return static_cast( + static_cast::type>(a) & + static_cast::type>(b)); +} + +class ArgumentParser; + +class Argument { + friend class ArgumentParser; + friend auto operator<<(std::ostream& stream, const ArgumentParser& parser) + -> std::ostream&; + + template + explicit Argument(std::string_view prefix_chars, + std::array&& a, + std::index_sequence /*unused*/) + : m_accepts_optional_like_value(false), + m_is_optional((is_optional(a[I], prefix_chars) || ...)), + m_is_required(false), m_is_repeatable(false), m_is_used(false), + m_is_hidden(false), m_prefix_chars(prefix_chars) { + ((void)m_names.emplace_back(a[I]), ...); + std::sort(m_names.begin(), m_names.end(), + [](const auto& lhs, const auto& rhs) { + return lhs.size() == rhs.size() ? lhs < rhs + : lhs.size() < rhs.size(); + }); + } + +public: + template + explicit Argument(std::string_view prefix_chars, + std::array&& a) + : Argument(prefix_chars, std::move(a), std::make_index_sequence{}) { + } + + Argument& help(std::string help_text) { + m_help = std::move(help_text); + return *this; + } + + Argument& metavar(std::string metavar) { + m_metavar = std::move(metavar); + return *this; + } + + template + Argument& default_value(T&& value) { + m_num_args_range = NArgsRange{0, m_num_args_range.get_max()}; + m_default_value_repr = details::repr(value); + + if constexpr (std::is_convertible_v) { + m_default_value_str = std::string{std::string_view{value}}; + } else if constexpr (details::can_invoke_to_string::value) { + m_default_value_str = std::to_string(value); + } + + m_default_value = std::forward(value); + return *this; + } + + Argument& default_value(const char* value) { + return default_value(std::string(value)); + } + + Argument& required() { + m_is_required = true; + return *this; + } + + Argument& implicit_value(std::any value) { + m_implicit_value = std::move(value); + m_num_args_range = NArgsRange{0, 0}; + return *this; + } + + // This is shorthand for: + // program.add_argument("foo") + // .default_value(false) + // .implicit_value(true) + Argument& flag() { + default_value(false); + implicit_value(true); + return *this; + } + + template + auto action(F&& callable, Args&&... bound_args) + -> std::enable_if_t, + Argument&> { + using action_type = std::conditional_t< + std::is_void_v>, + void_action, valued_action>; + if constexpr (sizeof...(Args) == 0) { + m_actions.emplace_back(std::forward(callable)); + } else { + m_actions.emplace_back( + [f = std::forward(callable), + tup = std::make_tuple(std::forward(bound_args)...)]( + std::string const& opt) mutable { + return details::apply_plus_one(f, tup, opt); + }); + } + return *this; + } + + auto& store_into(bool& var) { + if ((!m_default_value.has_value()) && (!m_implicit_value.has_value())) { + flag(); + } + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const auto& /*unused*/) { + var = true; + return var; + }); + return *this; + } + + template ::value>::type* = nullptr> + auto& store_into(T& var) { + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const auto& s) { + var = details::parse_number()(s); + return var; + }); + return *this; + } + + template ::value>::type* = + nullptr> + auto& store_into(T& var) { + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const auto& s) { + var = details::parse_number()(s); + return var; + }); + return *this; + } + + auto& store_into(std::string& var) { + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const std::string& s) { + var = s; + return var; + }); + return *this; + } + + auto& store_into(std::filesystem::path& var) { + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const std::string& s) { + var = s; + }); + return *this; + } + + auto& store_into(std::vector& var) { + if (m_default_value.has_value()) { + var = std::any_cast>(m_default_value); + } + action([this, &var](const std::string& s) { + if (!m_is_used) { + var.clear(); + } + m_is_used = true; + var.push_back(s); + return var; + }); + return *this; + } + + auto& store_into(std::vector& var) { + if (m_default_value.has_value()) { + var = std::any_cast>(m_default_value); + } + action([this, &var](const std::string& s) { + if (!m_is_used) { + var.clear(); + } + m_is_used = true; + var.push_back(details::parse_number()(s)); + return var; + }); + return *this; + } + + auto& store_into(std::set& var) { + if (m_default_value.has_value()) { + var = std::any_cast>(m_default_value); + } + action([this, &var](const std::string& s) { + if (!m_is_used) { + var.clear(); + } + m_is_used = true; + var.insert(s); + return var; + }); + return *this; + } + + auto& store_into(std::set& var) { + if (m_default_value.has_value()) { + var = std::any_cast>(m_default_value); + } + action([this, &var](const std::string& s) { + if (!m_is_used) { + var.clear(); + } + m_is_used = true; + var.insert(details::parse_number()(s)); + return var; + }); + return *this; + } + + auto& append() { + m_is_repeatable = true; + return *this; + } + + // Cause the argument to be invisible in usage and help + auto& hidden() { + m_is_hidden = true; + return *this; + } + + template + auto scan() -> std::enable_if_t, Argument&> { + static_assert(!(std::is_const_v || std::is_volatile_v), + "T should not be cv-qualified"); + auto is_one_of = [](char c, auto... x) constexpr { + return ((c == x) || ...); + }; + + if constexpr (is_one_of(Shape, 'd') && details::standard_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'i') && + details::standard_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'u') && + details::standard_unsigned_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'b') && + details::standard_unsigned_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'o') && + details::standard_unsigned_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'x', 'X') && + details::standard_unsigned_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'a', 'A') && + std::is_floating_point_v) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'e', 'E') && + std::is_floating_point_v) { + action( + details::parse_number()); + } else if constexpr (is_one_of(Shape, 'f', 'F') && + std::is_floating_point_v) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'g', 'G') && + std::is_floating_point_v) { + action(details::parse_number()); + } else { + static_assert(alignof(T) == 0, "No scan specification for T"); + } + + return *this; + } + + Argument& nargs(std::size_t num_args) { + m_num_args_range = NArgsRange{num_args, num_args}; + return *this; + } + + Argument& nargs(std::size_t num_args_min, std::size_t num_args_max) { + m_num_args_range = NArgsRange{num_args_min, num_args_max}; + return *this; + } + + Argument& nargs(nargs_pattern pattern) { + switch (pattern) { + case nargs_pattern::optional: + m_num_args_range = NArgsRange{0, 1}; + break; + case nargs_pattern::any: + m_num_args_range = + NArgsRange{0, (std::numeric_limits::max)()}; + break; + case nargs_pattern::at_least_one: + m_num_args_range = + NArgsRange{1, (std::numeric_limits::max)()}; + break; + } + return *this; + } + + Argument& remaining() { + m_accepts_optional_like_value = true; + return nargs(nargs_pattern::any); + } + + template + void add_choice(T&& choice) { + static_assert(details::IsChoiceTypeSupported::value, + "Only string or integer type supported for choice"); + static_assert(std::is_convertible_v || + details::can_invoke_to_string::value, + "Choice is not convertible to string_type"); + if (!m_choices.has_value()) { + m_choices = std::vector{}; + } + + if constexpr (std::is_convertible_v) { + m_choices.value().push_back( + std::string{std::string_view{std::forward(choice)}}); + } else if constexpr (details::can_invoke_to_string::value) { + m_choices.value().push_back( + std::to_string(std::forward(choice))); + } + } + + Argument& choices() { + if (!m_choices.has_value()) { + throw std::runtime_error("Zero choices provided"); + } + return *this; + } + + template + Argument& choices(T&& first, U&&... rest) { + add_choice(std::forward(first)); + choices(std::forward(rest)...); + return *this; + } + + void find_default_value_in_choices_or_throw() const { + + const auto& choices = m_choices.value(); + + if (m_default_value.has_value()) { + if (std::find(choices.begin(), choices.end(), + m_default_value_str) == choices.end()) { + // provided arg not in list of allowed choices + // report error + + std::string choices_as_csv = std::accumulate( + choices.begin(), choices.end(), std::string(), + [](const std::string& a, const std::string& b) { + return a + (a.empty() ? "" : ", ") + b; + }); + + throw std::runtime_error(std::string{"Invalid default value "} + + m_default_value_repr + + " - allowed options: {" + + choices_as_csv + "}"); + } + } + } + + template + bool is_value_in_choices(Iterator option_it) const { + + const auto& choices = m_choices.value(); + + return (std::find(choices.begin(), choices.end(), *option_it) != + choices.end()); + } + + template + void throw_invalid_arguments_error(Iterator option_it) const { + const auto& choices = m_choices.value(); + const std::string choices_as_csv = std::accumulate( + choices.begin(), choices.end(), std::string(), + [](const std::string& option_a, const std::string& option_b) { + return option_a + (option_a.empty() ? "" : ", ") + option_b; + }); + + throw std::runtime_error( + std::string{"Invalid argument "} + details::repr(*option_it) + + " - allowed options: {" + choices_as_csv + "}"); + } + + /* The dry_run parameter can be set to true to avoid running the actions, + * and setting m_is_used. This may be used by a pre-processing step to do + * a first iteration over arguments. + */ + template + Iterator consume(Iterator start, Iterator end, + std::string_view used_name = {}, bool dry_run = false) { + if (!m_is_repeatable && m_is_used) { + throw std::runtime_error( + std::string("Duplicate argument ").append(used_name)); + } + m_used_name = used_name; + + std::size_t passed_options = 0; + + if (m_choices.has_value()) { + // Check each value in (start, end) and make sure + // it is in the list of allowed choices/options + const auto max_number_of_args = m_num_args_range.get_max(); + const auto min_number_of_args = m_num_args_range.get_min(); + for (auto it = start; it != end; ++it) { + if (is_value_in_choices(it)) { + passed_options += 1; + continue; + } + + if ((passed_options >= min_number_of_args) && + (passed_options <= max_number_of_args)) { + break; + } + + throw_invalid_arguments_error(it); + } + } + + const auto num_args_max = (m_choices.has_value()) + ? passed_options + : m_num_args_range.get_max(); + const auto num_args_min = m_num_args_range.get_min(); + std::size_t dist = 0; + if (num_args_max == 0) { + if (!dry_run) { + m_values.emplace_back(m_implicit_value); + for (auto& action : m_actions) { + std::visit( + [&](const auto& f) { + f({}); + }, + action); + } + if (m_actions.empty()) { + std::visit( + [&](const auto& f) { + f({}); + }, + m_default_action); + } + m_is_used = true; + } + return start; + } + if ((dist = static_cast(std::distance(start, end))) >= + num_args_min) { + if (num_args_max < dist) { + end = std::next(start, + static_cast( + num_args_max)); + } + if (!m_accepts_optional_like_value) { + end = std::find_if(start, end, + std::bind(is_optional, std::placeholders::_1, + m_prefix_chars)); + dist = static_cast(std::distance(start, end)); + if (dist < num_args_min) { + throw std::runtime_error("Too few arguments for '" + + std::string(m_used_name) + "'."); + } + } + struct ActionApply { + void operator()(valued_action& f) { + std::transform(first, last, + std::back_inserter(self.m_values), f); + } + + void operator()(void_action& f) { + std::for_each(first, last, f); + if (!self.m_default_value.has_value()) { + if (!self.m_accepts_optional_like_value) { + self.m_values.resize(static_cast( + std::distance(first, last))); + } + } + } + + Iterator first, last; + Argument& self; + }; + if (!dry_run) { + for (auto& action : m_actions) { + std::visit(ActionApply{start, end, *this}, action); + } + if (m_actions.empty()) { + std::visit(ActionApply{start, end, *this}, + m_default_action); + } + m_is_used = true; + } + return end; + } + if (m_default_value.has_value()) { + if (!dry_run) { + m_is_used = true; + } + return start; + } + throw std::runtime_error("Too few arguments for '" + + std::string(m_used_name) + "'."); + } + + /* + * @throws std::runtime_error if argument values are not valid + */ + void validate() const { + if (m_is_optional) { + // TODO: check if an implicit value was programmed for this argument + if (!m_is_used && !m_default_value.has_value() && m_is_required) { + throw_required_arg_not_used_error(); + } + if (m_is_used && m_is_required && m_values.empty()) { + throw_required_arg_no_value_provided_error(); + } + } else { + if (!m_num_args_range.contains(m_values.size()) && + !m_default_value.has_value()) { + throw_nargs_range_validation_error(); + } + } + + if (m_choices.has_value()) { + // Make sure the default value (if provided) + // is in the list of choices + find_default_value_in_choices_or_throw(); + } + } + + std::string get_names_csv(char separator = ',') const { + return std::accumulate( + m_names.begin(), m_names.end(), std::string{""}, + [&](const std::string& result, const std::string& name) { + return result.empty() ? name : result + separator + name; + }); + } + + std::string get_usage_full() const { + std::stringstream usage; + + usage << get_names_csv('/'); + const std::string metavar = !m_metavar.empty() ? m_metavar : "VAR"; + if (m_num_args_range.get_max() > 0) { + usage << " " << metavar; + if (m_num_args_range.get_max() > 1) { + usage << "..."; + } + } + return usage.str(); + } + + std::string get_inline_usage() const { + std::stringstream usage; + // Find the longest variant to show in the usage string + std::string longest_name = m_names.front(); + for (const auto& s : m_names) { + if (s.size() > longest_name.size()) { + longest_name = s; + } + } + if (!m_is_required) { + usage << "["; + } + usage << longest_name; + const std::string metavar = !m_metavar.empty() ? m_metavar : "VAR"; + if (m_num_args_range.get_max() > 0) { + usage << " " << metavar; + if (m_num_args_range.get_max() > 1 && + m_metavar.find("> <") == std::string::npos) { + usage << "..."; + } + } + if (!m_is_required) { + usage << "]"; + } + if (m_is_repeatable) { + usage << "..."; + } + return usage.str(); + } + + std::size_t get_arguments_length() const { + + std::size_t names_size = + std::accumulate(std::begin(m_names), std::end(m_names), + std::size_t(0), [](const auto& sum, const auto& s) { + return sum + s.size(); + }); + + if (is_positional(m_names.front(), m_prefix_chars)) { + // A set metavar means this replaces the names + if (!m_metavar.empty()) { + // Indent and metavar + return 2 + m_metavar.size(); + } + + // Indent and space-separated + return 2 + names_size + (m_names.size() - 1); + } + // Is an option - include both names _and_ metavar + // size = text + (", " between names) + std::size_t size = names_size + 2 * (m_names.size() - 1); + if (!m_metavar.empty() && m_num_args_range == NArgsRange{1, 1}) { + size += m_metavar.size() + 1; + } + return size + 2; // indent + } + + friend std::ostream& operator<<(std::ostream& stream, + const Argument& argument) { + std::stringstream name_stream; + name_stream << " "; // indent + if (argument.is_positional(argument.m_names.front(), + argument.m_prefix_chars)) { + if (!argument.m_metavar.empty()) { + name_stream << argument.m_metavar; + } else { + name_stream << details::join(argument.m_names.begin(), + argument.m_names.end(), " "); + } + } else { + name_stream << details::join(argument.m_names.begin(), + argument.m_names.end(), ", "); + // If we have a metavar, and one narg - print the metavar + if (!argument.m_metavar.empty() && + argument.m_num_args_range == NArgsRange{1, 1}) { + name_stream << " " << argument.m_metavar; + } else if (!argument.m_metavar.empty() && + argument.m_num_args_range.get_min() == + argument.m_num_args_range.get_max() && + argument.m_metavar.find("> <") != std::string::npos) { + name_stream << " " << argument.m_metavar; + } + } + + // align multiline help message + auto stream_width = stream.width(); + auto name_padding = std::string(name_stream.str().size(), ' '); + auto pos = std::string::size_type{}; + auto prev = std::string::size_type{}; + auto first_line = true; + auto hspace = " "; // minimal space between name and help message + stream << name_stream.str(); + std::string_view help_view(argument.m_help); + while ((pos = argument.m_help.find('\n', prev)) != std::string::npos) { + auto line = help_view.substr(prev, pos - prev + 1); + if (first_line) { + stream << hspace << line; + first_line = false; + } else { + stream.width(stream_width); + stream << name_padding << hspace << line; + } + prev += pos - prev + 1; + } + if (first_line) { + stream << hspace << argument.m_help; + } else { + auto leftover = + help_view.substr(prev, argument.m_help.size() - prev); + if (!leftover.empty()) { + stream.width(stream_width); + stream << name_padding << hspace << leftover; + } + } + + // print nargs spec + if (!argument.m_help.empty()) { + stream << " "; + } + stream << argument.m_num_args_range; + + bool add_space = false; + if (argument.m_default_value.has_value() && + argument.m_num_args_range != NArgsRange{0, 0}) { + stream << "[default: " << argument.m_default_value_repr << "]"; + add_space = true; + } else if (argument.m_is_required) { + stream << "[required]"; + add_space = true; + } + if (argument.m_is_repeatable) { + if (add_space) { + stream << " "; + } + stream << "[may be repeated]"; + } + stream << "\n"; + return stream; + } + + template + bool operator!=(const T& rhs) const { + return !(*this == rhs); + } + + /* + * Compare to an argument value of known type + * @throws std::logic_error in case of incompatible types + */ + template + bool operator==(const T& rhs) const { + if constexpr (!details::IsContainer) { + return get() == rhs; + } else { + using ValueType = typename T::value_type; + auto lhs = get(); + return std::equal(std::begin(lhs), std::end(lhs), std::begin(rhs), + std::end(rhs), [](const auto& a, const auto& b) { + return std::any_cast(a) == + b; + }); + } + } + + /* + * positional: + * _empty_ + * '-' + * '-' decimal-literal + * !'-' anything + */ + static bool is_positional(std::string_view name, + std::string_view prefix_chars) { + auto first = lookahead(name); + + if (first == eof) { + return true; + } + if (prefix_chars.find(static_cast(first)) != + std::string_view::npos) { + name.remove_prefix(1); + if (name.empty()) { + return true; + } + return is_decimal_literal(name); + } + return true; + } + +private: + class NArgsRange { + std::size_t m_min; + std::size_t m_max; + + public: + NArgsRange(std::size_t minimum, std::size_t maximum) + : m_min(minimum), m_max(maximum) { + if (minimum > maximum) { + throw std::logic_error( + "Range of number of arguments is invalid"); + } + } + + bool contains(std::size_t value) const { + return value >= m_min && value <= m_max; + } + + bool is_exact() const { + return m_min == m_max; + } + + bool is_right_bounded() const { + return m_max < (std::numeric_limits::max)(); + } + + std::size_t get_min() const { + return m_min; + } + + std::size_t get_max() const { + return m_max; + } + + // Print help message + friend auto operator<<(std::ostream& stream, const NArgsRange& range) + -> std::ostream& { + if (range.m_min == range.m_max) { + if (range.m_min != 0 && range.m_min != 1) { + stream << "[nargs: " << range.m_min << "] "; + } + } else { + if (range.m_max == (std::numeric_limits::max)()) { + stream << "[nargs: " << range.m_min << " or more] "; + } else { + stream << "[nargs=" << range.m_min << ".." << range.m_max + << "] "; + } + } + return stream; + } + + bool operator==(const NArgsRange& rhs) const { + return rhs.m_min == m_min && rhs.m_max == m_max; + } + + bool operator!=(const NArgsRange& rhs) const { + return !(*this == rhs); + } + }; + + void throw_nargs_range_validation_error() const { + std::stringstream stream; + if (!m_used_name.empty()) { + stream << m_used_name << ": "; + } else { + stream << m_names.front() << ": "; + } + if (m_num_args_range.is_exact()) { + stream << m_num_args_range.get_min(); + } else if (m_num_args_range.is_right_bounded()) { + stream << m_num_args_range.get_min() << " to " + << m_num_args_range.get_max(); + } else { + stream << m_num_args_range.get_min() << " or more"; + } + stream << " argument(s) expected. " << m_values.size() << " provided."; + throw std::runtime_error(stream.str()); + } + + void throw_required_arg_not_used_error() const { + std::stringstream stream; + stream << m_names.front() << ": required."; + throw std::runtime_error(stream.str()); + } + + void throw_required_arg_no_value_provided_error() const { + std::stringstream stream; + stream << m_used_name << ": no value provided."; + throw std::runtime_error(stream.str()); + } + + static constexpr int eof = std::char_traits::eof(); + + static auto lookahead(std::string_view s) -> int { + if (s.empty()) { + return eof; + } + return static_cast(static_cast(s[0])); + } + + /* + * decimal-literal: + * '0' + * nonzero-digit digit-sequence_opt + * integer-part fractional-part + * fractional-part + * integer-part '.' exponent-part_opt + * integer-part exponent-part + * + * integer-part: + * digit-sequence + * + * fractional-part: + * '.' post-decimal-point + * + * post-decimal-point: + * digit-sequence exponent-part_opt + * + * exponent-part: + * 'e' post-e + * 'E' post-e + * + * post-e: + * sign_opt digit-sequence + * + * sign: one of + * '+' '-' + */ + static bool is_decimal_literal(std::string_view s) { + auto is_digit = [](auto c) constexpr { + switch (c) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + return true; + default: + return false; + } + }; + + // precondition: we have consumed or will consume at least one digit + auto consume_digits = [=](std::string_view sd) { + // NOLINTNEXTLINE(readability-qualified-auto) + auto it = std::find_if_not(std::begin(sd), std::end(sd), is_digit); + return sd.substr(static_cast(it - std::begin(sd))); + }; + + switch (lookahead(s)) { + case '0': { + s.remove_prefix(1); + if (s.empty()) { + return true; + } + goto integer_part; + } + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': { + s = consume_digits(s); + if (s.empty()) { + return true; + } + goto integer_part_consumed; + } + case '.': { + s.remove_prefix(1); + goto post_decimal_point; + } + default: + return false; + } + + integer_part: + s = consume_digits(s); + integer_part_consumed: + switch (lookahead(s)) { + case '.': { + s.remove_prefix(1); + if (is_digit(lookahead(s))) { + goto post_decimal_point; + } else { + goto exponent_part_opt; + } + } + case 'e': + case 'E': { + s.remove_prefix(1); + goto post_e; + } + default: + return false; + } + + post_decimal_point: + if (is_digit(lookahead(s))) { + s = consume_digits(s); + goto exponent_part_opt; + } + return false; + + exponent_part_opt: + switch (lookahead(s)) { + case eof: + return true; + case 'e': + case 'E': { + s.remove_prefix(1); + goto post_e; + } + default: + return false; + } + + post_e: + switch (lookahead(s)) { + case '-': + case '+': + s.remove_prefix(1); + } + if (is_digit(lookahead(s))) { + s = consume_digits(s); + return s.empty(); + } + return false; + } + + static bool is_optional(std::string_view name, + std::string_view prefix_chars) { + return !is_positional(name, prefix_chars); + } + + /* + * Get argument value given a type + * @throws std::logic_error in case of incompatible types + */ + template + T get() const { + if (!m_values.empty()) { + if constexpr (details::IsContainer) { + return any_cast_container(m_values); + } else { + return std::any_cast(m_values.front()); + } + } + if (m_default_value.has_value()) { + return std::any_cast(m_default_value); + } + if constexpr (details::IsContainer) { + if (!m_accepts_optional_like_value) { + return any_cast_container(m_values); + } + } + + throw std::logic_error("No value provided for '" + m_names.back() + + "'."); + } + + /* + * Get argument value given a type. + * @pre The object has no default value. + * @returns The stored value if any, std::nullopt otherwise. + */ + template + auto present() const -> std::optional { + if (m_default_value.has_value()) { + throw std::logic_error( + "Argument with default value always presents"); + } + if (m_values.empty()) { + return std::nullopt; + } + if constexpr (details::IsContainer) { + return any_cast_container(m_values); + } + return std::any_cast(m_values.front()); + } + + template + static auto any_cast_container(const std::vector& operand) -> T { + using ValueType = typename T::value_type; + + T result; + std::transform(std::begin(operand), std::end(operand), + std::back_inserter(result), [](const auto& value) { + return std::any_cast(value); + }); + return result; + } + + void set_usage_newline_counter(int i) { + m_usage_newline_counter = i; + } + + void set_group_idx(std::size_t i) { + m_group_idx = i; + } + + std::vector m_names; + std::string_view m_used_name; + std::string m_help; + std::string m_metavar; + std::any m_default_value; + std::string m_default_value_repr; + std::optional + m_default_value_str; // used for checking default_value against choices + std::any m_implicit_value; + std::optional> m_choices{std::nullopt}; + using valued_action = std::function; + using void_action = std::function; + std::vector> m_actions; + std::variant m_default_action{ + std::in_place_type, [](const std::string& value) { + return value; + }}; + std::vector m_values; + NArgsRange m_num_args_range{1, 1}; + // Bit field of bool values. Set default value in ctor. + bool m_accepts_optional_like_value : 1; + bool m_is_optional : 1; + bool m_is_required : 1; + bool m_is_repeatable : 1; + bool m_is_used : 1; + bool m_is_hidden : 1; // if set, does not appear in usage or help + std::string_view m_prefix_chars; // ArgumentParser has the prefix_chars + int m_usage_newline_counter = 0; + std::size_t m_group_idx = 0; +}; + +class ArgumentParser { +public: + explicit ArgumentParser(std::string program_name = {}, + std::string version = "1.0", + default_arguments add_args = default_arguments::all, + bool exit_on_default_arguments = true, + std::ostream& os = std::cout) + : m_program_name(std::move(program_name)), + m_version(std::move(version)), + m_exit_on_default_arguments(exit_on_default_arguments), + m_parser_path(m_program_name) { + if ((add_args & default_arguments::help) == default_arguments::help) { + add_argument("-h", "--help") + .action([&](const auto& /*unused*/) { + os << help().str(); + if (m_exit_on_default_arguments) { + std::exit(0); + } + }) + .default_value(false) + .help("shows help message and exits") + .implicit_value(true) + .nargs(0); + } + if ((add_args & default_arguments::version) == + default_arguments::version) { + add_argument("-v", "--version") + .action([&](const auto& /*unused*/) { + os << m_version << std::endl; + if (m_exit_on_default_arguments) { + std::exit(0); + } + }) + .default_value(false) + .help("prints version information and exits") + .implicit_value(true) + .nargs(0); + } + } + + ~ArgumentParser() = default; + + // ArgumentParser is meant to be used in a single function. + // Setup everything and parse arguments in one place. + // + // ArgumentParser internally uses std::string_views, + // references, iterators, etc. + // Many of these elements become invalidated after a copy or move. + ArgumentParser(const ArgumentParser& other) = delete; + ArgumentParser& operator=(const ArgumentParser& other) = delete; + ArgumentParser(ArgumentParser&&) noexcept = delete; + ArgumentParser& operator=(ArgumentParser&&) = delete; + + explicit operator bool() const { + auto arg_used = std::any_of(m_argument_map.cbegin(), + m_argument_map.cend(), [](auto& it) { + return it.second->m_is_used; + }); + auto subparser_used = std::any_of( + m_subparser_used.cbegin(), m_subparser_used.cend(), [](auto& it) { + return it.second; + }); + + return m_is_parsed && (arg_used || subparser_used); + } + + // Parameter packing + // Call add_argument with variadic number of string arguments + template + Argument& add_argument(Targs... f_args) { + using array_of_sv = std::array; + auto argument = m_optional_arguments.emplace( + std::cend(m_optional_arguments), m_prefix_chars, + array_of_sv{f_args...}); + + if (!argument->m_is_optional) { + m_positional_arguments.splice(std::cend(m_positional_arguments), + m_optional_arguments, argument); + } + argument->set_usage_newline_counter(m_usage_newline_counter); + argument->set_group_idx(m_group_names.size()); + + index_argument(argument); + return *argument; + } + + class MutuallyExclusiveGroup { + friend class ArgumentParser; + + public: + MutuallyExclusiveGroup() = delete; + + explicit MutuallyExclusiveGroup(ArgumentParser& parent, + bool required = false) + : m_parent(parent), m_required(required), m_elements({}) { + } + + MutuallyExclusiveGroup(const MutuallyExclusiveGroup& other) = delete; + MutuallyExclusiveGroup& + operator=(const MutuallyExclusiveGroup& other) = delete; + + MutuallyExclusiveGroup(MutuallyExclusiveGroup&& other) noexcept + : m_parent(other.m_parent), m_required(other.m_required), + m_elements(std::move(other.m_elements)) { + other.m_elements.clear(); + } + + template + Argument& add_argument(Targs... f_args) { + auto& argument = + m_parent.add_argument(std::forward(f_args)...); + m_elements.push_back(&argument); + argument.set_usage_newline_counter( + m_parent.m_usage_newline_counter); + argument.set_group_idx(m_parent.m_group_names.size()); + return argument; + } + + private: + ArgumentParser& m_parent; + bool m_required{false}; + std::vector m_elements{}; + }; + + MutuallyExclusiveGroup& + add_mutually_exclusive_group(bool required = false) { + m_mutually_exclusive_groups.emplace_back(*this, required); + return m_mutually_exclusive_groups.back(); + } + + // Parameter packed add_parents method + // Accepts a variadic number of ArgumentParser objects + template + ArgumentParser& add_parents(const Targs&... f_args) { + for (const ArgumentParser& parent_parser : {std::ref(f_args)...}) { + for (const auto& argument : parent_parser.m_positional_arguments) { + auto it = m_positional_arguments.insert( + std::cend(m_positional_arguments), argument); + index_argument(it); + } + for (const auto& argument : parent_parser.m_optional_arguments) { + auto it = m_optional_arguments.insert( + std::cend(m_optional_arguments), argument); + index_argument(it); + } + } + return *this; + } + + // Ask for the next optional arguments to be displayed on a separate + // line in usage() output. Only effective if set_usage_max_line_width() is + // also used. + ArgumentParser& add_usage_newline() { + ++m_usage_newline_counter; + return *this; + } + + // Ask for the next optional arguments to be displayed in a separate section + // in usage() and help (<< *this) output. + // For usage(), this is only effective if set_usage_max_line_width() is + // also used. + ArgumentParser& add_group(std::string group_name) { + m_group_names.emplace_back(std::move(group_name)); + return *this; + } + + ArgumentParser& add_description(std::string description) { + m_description = std::move(description); + return *this; + } + + ArgumentParser& add_epilog(std::string epilog) { + m_epilog = std::move(epilog); + return *this; + } + + // Add a un-documented/hidden alias for an argument. + // Ideally we'd want this to be a method of Argument, but Argument + // does not own its owing ArgumentParser. + ArgumentParser& add_hidden_alias_for(Argument& arg, + std::string_view alias) { + for (auto it = m_optional_arguments.begin(); + it != m_optional_arguments.end(); ++it) { + if (&(*it) == &arg) { + m_argument_map.insert_or_assign(std::string(alias), it); + return *this; + } + } + throw std::logic_error( + "Argument is not an optional argument of this parser"); + } + + /* Getter for arguments and subparsers. + * @throws std::logic_error in case of an invalid argument or subparser name + */ + template + T& at(std::string_view name) { + if constexpr (std::is_same_v) { + return (*this)[name]; + } else { + std::string str_name(name); + auto subparser_it = m_subparser_map.find(str_name); + if (subparser_it != m_subparser_map.end()) { + return subparser_it->second->get(); + } + throw std::logic_error("No such subparser: " + str_name); + } + } + + ArgumentParser& set_prefix_chars(std::string prefix_chars) { + m_prefix_chars = std::move(prefix_chars); + return *this; + } + + ArgumentParser& set_assign_chars(std::string assign_chars) { + m_assign_chars = std::move(assign_chars); + return *this; + } + + /* Call parse_args_internal - which does all the work + * Then, validate the parsed arguments + * This variant is used mainly for testing + * @throws std::runtime_error in case of any invalid argument + */ + void parse_args(const std::vector& arguments) { + parse_args_internal(arguments); + // Check if all arguments are parsed + for ([[maybe_unused]] const auto& [unused, argument] : m_argument_map) { + argument->validate(); + } + + // Check each mutually exclusive group and make sure + // there are no constraint violations + for (const auto& group : m_mutually_exclusive_groups) { + auto mutex_argument_used{false}; + Argument* mutex_argument_it{nullptr}; + for (Argument* arg : group.m_elements) { + if (!mutex_argument_used && arg->m_is_used) { + mutex_argument_used = true; + mutex_argument_it = arg; + } else if (mutex_argument_used && arg->m_is_used) { + // Violation + throw std::runtime_error( + "Argument '" + arg->get_usage_full() + + "' not allowed with '" + + mutex_argument_it->get_usage_full() + "'"); + } + } + + if (!mutex_argument_used && group.m_required) { + // at least one argument from the group is + // required + std::string argument_names{}; + std::size_t i = 0; + std::size_t size = group.m_elements.size(); + for (Argument* arg : group.m_elements) { + if (i + 1 == size) { + // last + argument_names += std::string("'") + + arg->get_usage_full() + + std::string("' "); + } else { + argument_names += std::string("'") + + arg->get_usage_full() + + std::string("' or "); + } + i += 1; + } + throw std::runtime_error("One of the arguments " + + argument_names + "is required"); + } + } + } + + /* Call parse_known_args_internal - which does all the work + * Then, validate the parsed arguments + * This variant is used mainly for testing + * @throws std::runtime_error in case of any invalid argument + */ + std::vector + parse_known_args(const std::vector& arguments) { + auto unknown_arguments = parse_known_args_internal(arguments); + // Check if all arguments are parsed + for ([[maybe_unused]] const auto& [unused, argument] : m_argument_map) { + argument->validate(); + } + return unknown_arguments; + } + + /* Main entry point for parsing command-line arguments using this + * ArgumentParser + * @throws std::runtime_error in case of any invalid argument + */ + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) + void parse_args(int argc, const char* const argv[]) { + parse_args({argv, argv + argc}); + } + + /* Main entry point for parsing command-line arguments using this + * ArgumentParser + * @throws std::runtime_error in case of any invalid argument + */ + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) + auto parse_known_args(int argc, const char* const argv[]) { + return parse_known_args({argv, argv + argc}); + } + + /* Getter for options with default values. + * @throws std::logic_error if parse_args() has not been previously called + * @throws std::logic_error if there is no such option + * @throws std::logic_error if the option has no value + * @throws std::bad_any_cast if the option is not of type T + */ + template + T get(std::string_view arg_name) const { + if (!m_is_parsed) { + throw std::logic_error( + "Nothing parsed, no arguments are available."); + } + return (*this)[arg_name].get(); + } + + /* Getter for options without default values. + * @pre The option has no default value. + * @throws std::logic_error if there is no such option + * @throws std::bad_any_cast if the option is not of type T + */ + template + auto present(std::string_view arg_name) const -> std::optional { + return (*this)[arg_name].present(); + } + + /* Getter that returns true for user-supplied options. Returns false if not + * user-supplied, even with a default value. + */ + auto is_used(std::string_view arg_name) const { + return (*this)[arg_name].m_is_used; + } + + /* Getter that returns true if a subcommand is used. + */ + auto is_subcommand_used(std::string_view subcommand_name) const { + return m_subparser_used.at(std::string(subcommand_name)); + } + + /* Getter that returns true if a subcommand is used. + */ + auto is_subcommand_used(const ArgumentParser& subparser) const { + return is_subcommand_used(subparser.m_program_name); + } + + /* Indexing operator. Return a reference to an Argument object + * Used in conjunction with Argument.operator== e.g., parser["foo"] == true + * @throws std::logic_error in case of an invalid argument name + */ + Argument& operator[](std::string_view arg_name) const { + std::string name(arg_name); + auto it = m_argument_map.find(name); + if (it != m_argument_map.end()) { + return *(it->second); + } + if (!is_valid_prefix_char(arg_name.front())) { + const auto legal_prefix_char = get_any_valid_prefix_char(); + const auto prefix = std::string(1, legal_prefix_char); + + // "-" + arg_name + name = prefix + name; + it = m_argument_map.find(name); + if (it != m_argument_map.end()) { + return *(it->second); + } + // "--" + arg_name + name = prefix + name; + it = m_argument_map.find(name); + if (it != m_argument_map.end()) { + return *(it->second); + } + } + throw std::logic_error("No such argument: " + std::string(arg_name)); + } + + // Print help message + friend auto operator<<(std::ostream& stream, const ArgumentParser& parser) + -> std::ostream& { + stream.setf(std::ios_base::left); + + auto longest_arg_length = parser.get_length_of_longest_argument(); + + stream << parser.usage() << "\n\n"; + + if (!parser.m_description.empty()) { + stream << parser.m_description << "\n\n"; + } + + const bool has_visible_positional_args = + std::find_if(parser.m_positional_arguments.begin(), + parser.m_positional_arguments.end(), + [](const auto& argument) { + return !argument.m_is_hidden; + }) != parser.m_positional_arguments.end(); + if (has_visible_positional_args) { + stream << "Positional arguments:\n"; + } + + for (const auto& argument : parser.m_positional_arguments) { + if (!argument.m_is_hidden) { + stream.width(static_cast(longest_arg_length)); + stream << argument; + } + } + + if (!parser.m_optional_arguments.empty()) { + stream << (!has_visible_positional_args ? "" : "\n") + << "Optional arguments:\n"; + } + + for (const auto& argument : parser.m_optional_arguments) { + if (argument.m_group_idx == 0 && !argument.m_is_hidden) { + stream.width(static_cast(longest_arg_length)); + stream << argument; + } + } + + for (size_t i_group = 0; i_group < parser.m_group_names.size(); + ++i_group) { + stream << "\n" + << parser.m_group_names[i_group] << " (detailed usage):\n"; + for (const auto& argument : parser.m_optional_arguments) { + if (argument.m_group_idx == i_group + 1 && + !argument.m_is_hidden) { + stream.width( + static_cast(longest_arg_length)); + stream << argument; + } + } + } + + bool has_visible_subcommands = + std::any_of(parser.m_subparser_map.begin(), + parser.m_subparser_map.end(), [](auto& p) { + return !p.second->get().m_suppress; + }); + + if (has_visible_subcommands) { + stream << (parser.m_positional_arguments.empty() + ? (parser.m_optional_arguments.empty() ? "" : "\n") + : "\n") + << "Subcommands:\n"; + for (const auto& [command, subparser] : parser.m_subparser_map) { + if (subparser->get().m_suppress) { + continue; + } + + stream << std::setw(2) << " "; + stream << std::setw(static_cast(longest_arg_length - 2)) + << command; + stream << " " << subparser->get().m_description << "\n"; + } + } + + if (!parser.m_epilog.empty()) { + stream << '\n'; + stream << parser.m_epilog << "\n\n"; + } + + return stream; + } + + // Format help message + auto help() const -> std::stringstream { + std::stringstream out; + out << *this; + return out; + } + + // Sets the maximum width for a line of the Usage message + ArgumentParser& set_usage_max_line_width(size_t w) { + this->m_usage_max_line_width = w; + return *this; + } + + // Asks to display arguments of mutually exclusive group on separate lines + // in the Usage message + ArgumentParser& set_usage_break_on_mutex() { + this->m_usage_break_on_mutex = true; + return *this; + } + + // Format usage part of help only + auto usage() const -> std::string { + std::stringstream stream; + + std::string curline("Usage: "); + curline += this->m_parser_path; + const bool multiline_usage = this->m_usage_max_line_width < + (std::numeric_limits::max)(); + const size_t indent_size = curline.size(); + + const auto deal_with_options_of_group = [&](std::size_t group_idx) { + bool found_options = false; + // Add any options inline here + const MutuallyExclusiveGroup* cur_mutex = nullptr; + int usage_newline_counter = -1; + for (const auto& argument : this->m_optional_arguments) { + if (argument.m_is_hidden) { + continue; + } + if (multiline_usage) { + if (argument.m_group_idx != group_idx) { + continue; + } + if (usage_newline_counter != + argument.m_usage_newline_counter) { + if (usage_newline_counter >= 0) { + if (curline.size() > indent_size) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + } + usage_newline_counter = + argument.m_usage_newline_counter; + } + } + found_options = true; + const std::string arg_inline_usage = + argument.get_inline_usage(); + const MutuallyExclusiveGroup* arg_mutex = + get_belonging_mutex(&argument); + if ((cur_mutex != nullptr) && (arg_mutex == nullptr)) { + curline += ']'; + if (this->m_usage_break_on_mutex) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + } else if ((cur_mutex == nullptr) && (arg_mutex != nullptr)) { + if ((this->m_usage_break_on_mutex && + curline.size() > indent_size) || + curline.size() + 3 + arg_inline_usage.size() > + this->m_usage_max_line_width) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + curline += " ["; + } else if ((cur_mutex != nullptr) && (arg_mutex != nullptr)) { + if (cur_mutex != arg_mutex) { + curline += ']'; + if (this->m_usage_break_on_mutex || + curline.size() + 3 + arg_inline_usage.size() > + this->m_usage_max_line_width) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + curline += " ["; + } else { + curline += '|'; + } + } + cur_mutex = arg_mutex; + if (curline.size() != indent_size && + curline.size() + 1 + arg_inline_usage.size() > + this->m_usage_max_line_width) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + curline += " "; + } else if (cur_mutex == nullptr) { + curline += " "; + } + curline += arg_inline_usage; + } + if (cur_mutex != nullptr) { + curline += ']'; + } + return found_options; + }; + + const bool found_options = deal_with_options_of_group(0); + + if (found_options && multiline_usage && + !this->m_positional_arguments.empty()) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + // Put positional arguments after the optionals + for (const auto& argument : this->m_positional_arguments) { + if (argument.m_is_hidden) { + continue; + } + const std::string pos_arg = !argument.m_metavar.empty() + ? argument.m_metavar + : argument.m_names.front(); + if (curline.size() + 1 + pos_arg.size() > + this->m_usage_max_line_width) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + curline += " "; + if (argument.m_num_args_range.get_min() == 0 && + !argument.m_num_args_range.is_right_bounded()) { + curline += "["; + curline += pos_arg; + curline += "]..."; + } else if (argument.m_num_args_range.get_min() == 1 && + !argument.m_num_args_range.is_right_bounded()) { + curline += pos_arg; + curline += "..."; + } else { + curline += pos_arg; + } + } + + if (multiline_usage) { + // Display options of other groups + for (std::size_t i = 0; i < m_group_names.size(); ++i) { + stream << curline << std::endl << std::endl; + stream << m_group_names[i] << ":" << std::endl; + curline = std::string(indent_size, ' '); + deal_with_options_of_group(i + 1); + } + } + + stream << curline; + + // Put subcommands after positional arguments + if (!m_subparser_map.empty()) { + stream << " {"; + std::size_t i{0}; + for (const auto& [command, subparser] : m_subparser_map) { + if (subparser->get().m_suppress) { + continue; + } + + if (i == 0) { + stream << command; + } else { + stream << "," << command; + } + ++i; + } + stream << "}"; + } + + return stream.str(); + } + + // Printing the one and only help message + // I've stuck with a simple message format, nothing fancy. + [[deprecated( + "Use cout << program; instead. See also help().")]] std::string + print_help() const { + auto out = help(); + std::cout << out.rdbuf(); + return out.str(); + } + + void add_subparser(ArgumentParser& parser) { + parser.m_parser_path = m_program_name + " " + parser.m_program_name; + auto it = m_subparsers.emplace(std::cend(m_subparsers), parser); + m_subparser_map.insert_or_assign(parser.m_program_name, it); + m_subparser_used.insert_or_assign(parser.m_program_name, false); + } + + void set_suppress(bool suppress) { + m_suppress = suppress; + } + +protected: + const MutuallyExclusiveGroup* + get_belonging_mutex(const Argument* arg) const { + for (const auto& mutex : m_mutually_exclusive_groups) { + if (std::find(mutex.m_elements.begin(), mutex.m_elements.end(), + arg) != mutex.m_elements.end()) { + return &mutex; + } + } + return nullptr; + } + + bool is_valid_prefix_char(char c) const { + return m_prefix_chars.find(c) != std::string::npos; + } + + char get_any_valid_prefix_char() const { + return m_prefix_chars[0]; + } + + /* + * Pre-process this argument list. Anything starting with "--", that + * contains an =, where the prefix before the = has an entry in the + * options table, should be split. + */ + std::vector + preprocess_arguments(const std::vector& raw_arguments) const { + std::vector arguments{}; + for (const auto& arg : raw_arguments) { + + const auto argument_starts_with_prefix_chars = + [this](const std::string& a) -> bool { + if (!a.empty()) { + + const auto legal_prefix = [this](char c) -> bool { + return m_prefix_chars.find(c) != std::string::npos; + }; + + // Windows-style + // if '/' is a legal prefix char + // then allow single '/' followed by argument name, followed + // by an assign char, e.g., ':' e.g., 'test.exe /A:Foo' + const auto windows_style = legal_prefix('/'); + + if (windows_style) { + if (legal_prefix(a[0])) { + return true; + } + } else { + // Slash '/' is not a legal prefix char + // For all other characters, only support long arguments + // i.e., the argument must start with 2 prefix chars, + // e.g, + // '--foo' e,g, './test --foo=Bar -DARG=yes' + if (a.size() > 1) { + return (legal_prefix(a[0]) && legal_prefix(a[1])); + } + } + } + return false; + }; + + // Check that: + // - We don't have an argument named exactly this + // - The argument starts with a prefix char, e.g., "--" + // - The argument contains an assign char, e.g., "=" + auto assign_char_pos = arg.find_first_of(m_assign_chars); + + if (m_argument_map.find(arg) == m_argument_map.end() && + argument_starts_with_prefix_chars(arg) && + assign_char_pos != std::string::npos) { + // Get the name of the potential option, and check it exists + std::string opt_name = arg.substr(0, assign_char_pos); + if (m_argument_map.find(opt_name) != m_argument_map.end()) { + // This is the name of an option! Split it into two parts + arguments.push_back(std::move(opt_name)); + arguments.push_back(arg.substr(assign_char_pos + 1)); + continue; + } + } + // If we've fallen through to here, then it's a standard argument + arguments.push_back(arg); + } + return arguments; + } + + /* + * @throws std::runtime_error in case of any invalid argument + */ + void parse_args_internal(const std::vector& raw_arguments) { + auto arguments = preprocess_arguments(raw_arguments); + if (m_program_name.empty() && !arguments.empty()) { + m_program_name = arguments.front(); + } + auto end = std::end(arguments); + auto positional_argument_it = std::begin(m_positional_arguments); + for (auto it = std::next(std::begin(arguments)); it != end;) { + const auto& current_argument = *it; + if (Argument::is_positional(current_argument, m_prefix_chars)) { + if (positional_argument_it == + std::end(m_positional_arguments)) { + + // Check sub-parsers + auto subparser_it = m_subparser_map.find(current_argument); + if (subparser_it != m_subparser_map.end()) { + + // build list of remaining args + const auto unprocessed_arguments = + std::vector(it, end); + + // invoke subparser + m_is_parsed = true; + m_subparser_used[current_argument] = true; + return subparser_it->second->get().parse_args( + unprocessed_arguments); + } + + if (m_positional_arguments.empty()) { + + // Ask the user if they argument they provided was a + // typo for some sub-parser, e.g., user provided `git + // totes` instead of `git notes` + if (!m_subparser_map.empty()) { + throw std::runtime_error( + "Failed to parse '" + current_argument + + "', did you mean '" + + std::string{details::get_most_similar_string( + m_subparser_map, current_argument)} + + "'"); + } + + // Ask the user if they meant to use a specific optional + // argument + if (!m_optional_arguments.empty()) { + for (const auto& opt : m_optional_arguments) { + if (!opt.m_implicit_value.has_value()) { + // not a flag, requires a value + if (!opt.m_is_used) { + throw std::runtime_error( + "Zero positional arguments " + "expected, did you mean " + + opt.get_usage_full()); + } + } + } + + throw std::runtime_error( + "Zero positional arguments expected"); + } else { + throw std::runtime_error( + "Zero positional arguments expected"); + } + } else { + throw std::runtime_error( + "Maximum number of positional arguments " + "exceeded, failed to parse '" + + current_argument + "'"); + } + } + auto argument = positional_argument_it++; + + // Deal with the situation of ... + // + if (argument->m_num_args_range.get_min() == 1 && + argument->m_num_args_range.get_max() == + (std::numeric_limits::max)() && + positional_argument_it != + std::end(m_positional_arguments) && + std::next(positional_argument_it) == + std::end(m_positional_arguments) && + positional_argument_it->m_num_args_range.get_min() == 1 && + positional_argument_it->m_num_args_range.get_max() == 1) { + if (std::next(it) != end) { + positional_argument_it->consume(std::prev(end), end); + end = std::prev(end); + } else { + throw std::runtime_error( + "Missing " + + positional_argument_it->m_names.front()); + } + } + + it = argument->consume(it, end); + continue; + } + + auto arg_map_it = m_argument_map.find(current_argument); + if (arg_map_it != m_argument_map.end()) { + auto argument = arg_map_it->second; + it = argument->consume(std::next(it), end, arg_map_it->first); + } else if (const auto& compound_arg = current_argument; + compound_arg.size() > 1 && + is_valid_prefix_char(compound_arg[0]) && + !is_valid_prefix_char(compound_arg[1])) { + ++it; + for (std::size_t j = 1; j < compound_arg.size(); j++) { + auto hypothetical_arg = std::string{'-', compound_arg[j]}; + auto arg_map_it2 = m_argument_map.find(hypothetical_arg); + if (arg_map_it2 != m_argument_map.end()) { + auto argument = arg_map_it2->second; + it = argument->consume(it, end, arg_map_it2->first); + } else { + throw std::runtime_error("Unknown argument: " + + current_argument); + } + } + } else { + throw std::runtime_error("Unknown argument: " + + current_argument); + } + } + m_is_parsed = true; + } + + /* + * Like parse_args_internal but collects unused args into a vector + */ + std::vector + parse_known_args_internal(const std::vector& raw_arguments) { + auto arguments = preprocess_arguments(raw_arguments); + + std::vector unknown_arguments{}; + + if (m_program_name.empty() && !arguments.empty()) { + m_program_name = arguments.front(); + } + auto end = std::end(arguments); + auto positional_argument_it = std::begin(m_positional_arguments); + for (auto it = std::next(std::begin(arguments)); it != end;) { + const auto& current_argument = *it; + if (Argument::is_positional(current_argument, m_prefix_chars)) { + if (positional_argument_it == + std::end(m_positional_arguments)) { + + // Check sub-parsers + auto subparser_it = m_subparser_map.find(current_argument); + if (subparser_it != m_subparser_map.end()) { + + // build list of remaining args + const auto unprocessed_arguments = + std::vector(it, end); + + // invoke subparser + m_is_parsed = true; + m_subparser_used[current_argument] = true; + return subparser_it->second->get() + .parse_known_args_internal(unprocessed_arguments); + } + + // save current argument as unknown and go to next argument + unknown_arguments.push_back(current_argument); + ++it; + } else { + // current argument is the value of a positional argument + // consume it + auto argument = positional_argument_it++; + it = argument->consume(it, end); + } + continue; + } + + auto arg_map_it = m_argument_map.find(current_argument); + if (arg_map_it != m_argument_map.end()) { + auto argument = arg_map_it->second; + it = argument->consume(std::next(it), end, arg_map_it->first); + } else if (const auto& compound_arg = current_argument; + compound_arg.size() > 1 && + is_valid_prefix_char(compound_arg[0]) && + !is_valid_prefix_char(compound_arg[1])) { + ++it; + for (std::size_t j = 1; j < compound_arg.size(); j++) { + auto hypothetical_arg = std::string{'-', compound_arg[j]}; + auto arg_map_it2 = m_argument_map.find(hypothetical_arg); + if (arg_map_it2 != m_argument_map.end()) { + auto argument = arg_map_it2->second; + it = argument->consume(it, end, arg_map_it2->first); + } else { + unknown_arguments.push_back(current_argument); + break; + } + } + } else { + // current argument is an optional-like argument that is unknown + // save it and move to next argument + unknown_arguments.push_back(current_argument); + ++it; + } + } + m_is_parsed = true; + return unknown_arguments; + } + + // Used by print_help. + std::size_t get_length_of_longest_argument() const { + if (m_argument_map.empty()) { + return 0; + } + std::size_t max_size = 0; + for ([[maybe_unused]] const auto& [unused, argument] : m_argument_map) { + max_size = std::max(max_size, + argument->get_arguments_length()); + } + for ([[maybe_unused]] const auto& [command, unused] : m_subparser_map) { + max_size = std::max(max_size, command.size()); + } + return max_size; + } + + using argument_it = std::list::iterator; + using mutex_group_it = std::vector::iterator; + using argument_parser_it = + std::list>::iterator; + + void index_argument(argument_it it) { + for (const auto& name : std::as_const(it->m_names)) { + m_argument_map.insert_or_assign(name, it); + } + } + + std::string m_program_name; + std::string m_version; + std::string m_description; + std::string m_epilog; + bool m_exit_on_default_arguments = true; + std::string m_prefix_chars{"-"}; + std::string m_assign_chars{"="}; + bool m_is_parsed = false; + std::list m_positional_arguments; + std::list m_optional_arguments; + std::map m_argument_map; + std::string m_parser_path; + std::list> m_subparsers; + std::map m_subparser_map; + std::map m_subparser_used; + std::vector m_mutually_exclusive_groups; + bool m_suppress = false; + std::size_t m_usage_max_line_width = + (std::numeric_limits::max)(); + bool m_usage_break_on_mutex = false; + int m_usage_newline_counter = 0; + std::vector m_group_names; +}; + +} // namespace argparse diff --git a/cpp/examples/toy_homotopy/toy_homotopy.cpp b/cpp/examples/toy_homotopy/toy_homotopy.cpp new file mode 100644 index 0000000..f88be14 --- /dev/null +++ b/cpp/examples/toy_homotopy/toy_homotopy.cpp @@ -0,0 +1,187 @@ +// STL includes +#include + +// Library includes +#include +#include +#include + +// Project includes +#include "argparse.hpp" + + +// +// +// Homotopy definition +// +// + + +/// +/// @brief Helper type implementing necessary functions for PathTracker +/// @details Toy example homotopy: +/// G = [[x1], +/// [x2]] +/// +/// F = [[x1 + x2 ], +/// [x2 + 0.5]] +/// +/// H = (1-t)*G + t*F +/// +/// @details Note that +/// y := [[x1], +/// [x2], +/// [t]] +/// +struct ToyHomotopy { + /// + /// @brief Evaluate H at y + /// + static Eigen::VectorXd evaluate_H(Eigen::VectorXd y) { + Eigen::VectorXd result(2); + + double x1 = y(0); + double x2 = y(1); + double t = y(2); + + result(0) = x1 + t * x2; + result(1) = x2 + t * 0.5; + + return result; + } + + /// + /// @brief Evaluate Jacobian of H at y + /// + static Eigen::MatrixXd evaluate_DH(Eigen::VectorXd y) { + Eigen::MatrixXd result(2, 3); + + double x1 = y(0); + double x2 = y(1); + double t = y(2); + + result(0, 0) = 1; + result(0, 1) = t; + result(0, 2) = x2; + result(1, 0) = 0; + result(1, 1) = 1; + result(1, 2) = 0.5; + + return result; + } +}; + + +// +// +// Perform path tracking +// +// + + +void tracker_example(hccd::Settings settings, std::size_t num_iterations, + std::string output = "temp.csv") { + + hccd::PathTracker tracker{settings}; + + std::ofstream out; + if (output != "") { + out.open(output); + out << "x1b," << "x2b," << "tb," << "x1p," << "x2p," << "tp," << "x1e," + << "x2e," << "te," << "x1n," << "x2n," << "tn," << std::endl; + } + + Eigen::VectorXd y = Eigen::VectorXd::Zero(3); + + for (int i = 0; i < num_iterations; ++i) { + spdlog::info("Iteration {}", i); + + auto res = tracker.transparent_step(y); + if (!res) { + spdlog::error( + "Newton corrector failed to converge on iteration {} ", i); + std::terminate(); + } + + Eigen::VectorXd y_start, y_prime, y_hat_e; + std::tie(y_start, y_prime, y_hat_e, y) = res.value(); + + spdlog::info("y:{}", y); + + if (output != "") { + out << y_start(0) << "," << y_start(1) << "," << y_start(2) << "," + << y_prime(0) << "," << y_prime(1) << "," << y_prime(2) << "," + << y_hat_e(0) << "," << y_hat_e(1) << "," << y_hat_e(2) << "," + << y(0) << "," << y(1) << "," << y(2) << std::endl; + } + } +} + + +// +// +// User Interface +// +// + + +int main(int argc, char* argv[]) { + // Parse command line arguments + + argparse::ArgumentParser program("Homotopy continuation path tracker"); + + program.add_argument("--verbose") + .default_value(false) + .implicit_value(true); + program.add_argument("--euler-step-size") + .help("Step size for Euler predictor") + .default_value(0.05) + .scan<'g', double>(); + program.add_argument("--euler-max-tries") + .help("Maximum number of tries for Euler predictor") + .default_value(5) + .scan<'u', unsigned>(); + program.add_argument("--newton-max-iter") + .help("Maximum number of iterations for Newton corrector") + .default_value(5) + .scan<'u', unsigned>(); + program.add_argument("--newton-convergence-threshold") + .help("Convergence threshold for Newton corrector") + .default_value(0.01) + .scan<'g', double>(); + program.add_argument("-s", "--sigma") + .help("Direction in which the path is traced") + .default_value(1) + .scan<'i', int>(); + + program.add_argument("-o", "--output") + .help("Output csv file") + .default_value(""); + program.add_argument("-n", "--num-iterations") + .help("Number of iterations of the example program to run") + .default_value(20) + .scan<'u', std::size_t>(); + + try { + program.parse_args(argc, argv); + } catch (const std::runtime_error& err) { + spdlog::error("{}", err.what()); + std::terminate(); + } + + // Run program + + if (program["--verbose"] == true) spdlog::set_level(spdlog::level::debug); + + hccd::Settings settings{ + .euler_step_size = program.get("--euler-step-size"), + .euler_max_tries = program.get("--euler-max-tries"), + .newton_max_iter = program.get("--newton-max-iter"), + .newton_convergence_threshold = + program.get("--newton-convergence-threshold"), + .sigma = program.get("--sigma"), + }; + + tracker_example(settings, program.get("--num-iterations"), + program.get("--output")); +} diff --git a/cpp/include/hccd/PathTracker.hpp b/cpp/include/hccd/PathTracker.hpp new file mode 100644 index 0000000..f01a376 --- /dev/null +++ b/cpp/include/hccd/PathTracker.hpp @@ -0,0 +1,156 @@ +#pragma once + + +// STL includes +#include + +// Library includes +#include +#include + +// Project includes +#include "util.hpp" + + +namespace hccd { + + +namespace detail { +template +concept homotopy_c = requires(T) { + { T::evaluate_H(Eigen::VectorXd()) } -> std::same_as; + { T::evaluate_DH(Eigen::VectorXd()) } -> std::same_as; +}; +} // namespace detail + + +/// +/// @brief Settings for PathTracker +/// +struct Settings { + double euler_step_size = 0.05; + unsigned euler_max_tries = 5; + unsigned newton_max_iter = 5; + double newton_convergence_threshold = 0.01; + int sigma = 1; ///< Direction in which the path is traced +}; + +/// +/// @brief Path tracker for the homotopy continuation method +/// @details Uses a predictor-corrector scheme to trace a path defined by a +/// homotopy. +/// @details References: +/// [1] T. Chen and T.-Y. Li, “Homotopy continuation method for solving +/// systems of nonlinear and polynomial equations,” Communications in +/// Information and Systems, vol. 15, no. 2, pp. 119–307, 2015 +/// +/// @tparam homotopy_c Homotopy defining the path +/// +template +class PathTracker { +public: + enum Error { NewtonNotConverged }; + + PathTracker(Settings settings) : m_settings{settings} { + } + + /// + /// @brief Perform one predictor-corrector step + /// + Eigen::VectorXd step(Eigen::VectorXd y) { + auto res = transparent_step(y); + + if (res) { + return res.value().first; + } else { + return std::unexpected(res.error()); + } + } + + /// + /// @brief Perform one predictor-corrector step, returning intermediate + /// results + /// + std::expected, + Error> + transparent_step(Eigen::VectorXd y) { + for (int i = 0; i < m_settings.euler_max_tries; ++i) { + double step_size = m_settings.euler_step_size / (1 << i); + + const auto [y_hat, y_prime] = + perform_euler_predictor_step(y, step_size); + + auto res = perform_newton_corrector_step(y_hat); + + if (res) return {{y, y_prime, y_hat, res.value()}}; + } + + return std::unexpected(Error::NewtonNotConverged); + } + +private: + Settings m_settings; + + std::pair + perform_euler_predictor_step(Eigen::VectorXd y, double step_size) { + /// Obtain y_prime + + Eigen::MatrixXd DH = Homotopy::evaluate_DH(y); + auto qr = DH.transpose().colPivHouseholderQr(); + Eigen::MatrixXd Q = qr.matrixQ(); + Eigen::MatrixXd R = qr.matrixR(); + + Eigen::VectorXd y_prime = Q.col(2); + + spdlog::debug("Q: \t\t{}x{}; det={}", Q.rows(), Q.cols(), + Q.determinant()); + spdlog::debug("R.topRows(2): {}x{}; det={}", R.topRows(2).rows(), + R.topRows(2).cols(), R.topRows(2).determinant()); + + if (sign(Q.determinant() * R.topRows(2).determinant()) != + sign(m_settings.sigma)) + y_prime = -y_prime; + + /// Perform prediction + + Eigen::VectorXd y_hat = y + step_size * y_prime; + + return {y_hat, y_prime}; + } + + std::expected + perform_newton_corrector_step(Eigen::VectorXd y) { + Eigen::VectorXd prev_y = y; + + for (int i = 0; i < m_settings.newton_max_iter; ++i) { + + /// Perform correction + + Eigen::MatrixXd DH = Homotopy::evaluate_DH(y); + Eigen::MatrixXd DH_pinv = + DH.completeOrthogonalDecomposition().pseudoInverse(); + + y = y - DH_pinv * Homotopy::evaluate_H(y); + + /// Check stopping criterion + + spdlog::debug("Newton iteration {}: ||y-prev_y||={}", i, + (y - prev_y).norm()); + if ((y - prev_y).norm() < m_settings.newton_convergence_threshold) + return y; + + prev_y = y; + } + + return std::unexpected(Error::NewtonNotConverged); + } + + template + static int sign(T val) { + return -1 * (val < T(0)) + 1 * (val >= T(0)); + } +}; + + +} // namespace hccd diff --git a/cpp/include/hccd/util.hpp b/cpp/include/hccd/util.hpp new file mode 100644 index 0000000..2434c1a --- /dev/null +++ b/cpp/include/hccd/util.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include +// #include +#include +#include + + +/// @brief Boilerplate code to enable fmtlib to print Eigen::MatrixXd +template <> +struct fmt::formatter : fmt::ostream_formatter {}; +/// @brief Boilerplate code to enable fmtlib to print Eigen::VectorXd +template <> +struct fmt::formatter : fmt::ostream_formatter { + + template + auto format(const Eigen::VectorXd& value, Context& ctx) const + -> decltype(ctx.out()) { + + auto buffer = basic_memory_buffer(); + auto&& formatbuf = + detail::formatbuf>(buffer); + auto&& output = std::basic_ostream(&formatbuf); + output.imbue(std::locale::classic()); + + output << "["; + for (int i = 0; i < value.size(); ++i) { + output << value(i); + if (i < value.size() - 1) output << ", "; + } + output << "]"; + + output.exceptions(std::ios_base::failbit | std::ios_base::badbit); + return formatter, char>::format( + {buffer.data(), buffer.size()}, ctx); + } +}; diff --git a/docs/general_idea.md b/docs/general_idea.md new file mode 100644 index 0000000..73b6fbc --- /dev/null +++ b/docs/general_idea.md @@ -0,0 +1,99 @@ +# Homotopy Continuation + +### Introduction + +The aim of a homotopy method consists in solving a system of N nonlinear +equations in N variables \[1, p.1\]: + +$$ +F(\bm{x}) = 0, \hspace{5mm} F: \mathbb{R}^N \rightarrow \mathbb{R}^N +$$ + +This is achieved by defining a homotopy, or deformation, +$H : \mathbb{R}^N \rightarrow \mathbb{R}^N$ such that + +$$ +H(\bm{x}, 0) = G(\bm{x}), \hspace{5mm} H(\bm{x},1) = F(\bm{x}), +$$ + +where $G: \mathbb{R}^N \rightarrow \mathbb{R}^N$ is trivial, having known zero +points. One can then trace an implicitly defined curve +$\bm{c}(s) \in H^{-1}(0)$ from $(\bm{x}_1,0)$ to $(\bm{x}_2, 1)$ [2, p.122]. +Typically, a convex homotopy such as + +$$ +H(\bm{x}, t) := tG(\bm{x}) + (1-t)F(\bm{x}) +$$ + +is chosen [1, p.2]. + +### Implicit Definition of Solution Curve + +"In the construction of the homotopy H(x, t) = 0, it may seem natural to use t +as the designated parameter for the solution curve [...]. However, this +parametrization has a severe limitation, as t cannot be used as a smooth +parameter in certain situations. For example, [...] at points where ∂H/∂x is +singular the solution curve of H(x, t) = 0 cannot be parametrized by t +directly. Actually, one may avoid such difficulties by considering both x and t +as independent variables and parametrize the smooth curve γ by the arc-length." +[2, p. 125-126]. + +We can show that the solution curve can be implicitly defined as + +$$ +\begin{align} +DH(\bm{y}(s))\cdot \dot{\bm{y}}(s) = 0 \\[.5em] +\text{det}\left(\begin{array}{c} DH(\bm{y}(s)) \\ \dot{\bm{y}}(s)\end{array}\right) = \sigma_0 \\[.5em] +\lVert \dot{\bm{y}}(s) \rVert = 1 \\[.5em] +\bm{y}(0) = (\bm{x}_0, 0) +. +\end{align} +$$ + +where $DH(y)$ is the Jacobian of $H(y)$ [2, p.127] [1, p.9]. Equation (1) +corresponds to $\dot{y}(s)$ being tangent to the solution curve, the purpose of +(2) is to fix the direction in which the curve is traced. + +### Tracing solution curves + +"In principle, any of the available ODE solvers capable of integrating the +above system can be used to trace the curve [...]. However, due to numerical +stability concerns, a more preferable method to trace the curve is the +'prediction-correction scheme'. [...] Among many different choices for the +'predictors' as well as 'correctors', we shall present here a commonly used +combination of the (generalized) Euler’s method for the predictor and Newton’s +method for the corrector." [2, p.127] + +#### Euler's Predictor + +$$ +\hat{\bm{y}} = \bm{y}_0 + \Delta s \cdot \sigma \cdot \Delta \bm{y}, +$$ + +with $\Delta s$ denoting the step size and $\Delta y$ the numerical +approximation of $\dot{y}(s)$ [2, p.128]. + +#### Newton's corrector + +$$ +\bm{y} = \mathcal{N}^k(\hat{\bm{y}}), \hspace{5mm} \mathcal{N}(\hat{\bm{y}}) := \hat{\bm{y}} - (DH(\hat{\bm{y}}))^{+} H(\hat{\bm{y}}). +$$ + +" [...] the shrinking distance +$\lVert \mathcal{N}^j(\hat{\bm{y}}) - \mathcal{N}^{j-1} (\hat{\bm{y}})\rVert$ +between successive points produced by the iterations can be used as a criterion +for convergence. Of course, if the iterations fail to converge, one must go +back to adjust the step size for the Euler’s predictor." [2, p.130] + +______________________________________________________________________ + +## References + +\[1\]: E. L. Allgower and K. Georg, Introduction to numerical continuation +methods. in Classics in applied mathematics, no. 45. Philadelphia, Pa: Society +for Industrial and Applied Mathematics (SIAM, 3600 Market Street, Floor 6, +Philadelphia, PA 19104), 2003. doi: 10.1137/1.9780898719154. + +\[2\]: T. Chen and T.-Y. Li, “Homotopy continuation method for solving systems +of nonlinear and polynomial equations,” Communications in Information and +Systems, vol. 15, no. 2, pp. 119–307, 2015, doi: 10.4310/CIS.2015.v15.n2.a1. diff --git a/python/examples/repetition_code.py b/python/examples/repetition_code.py new file mode 100644 index 0000000..005e405 --- /dev/null +++ b/python/examples/repetition_code.py @@ -0,0 +1,142 @@ +import argparse +import numpy as np +import pandas as pd + +# autopep8: off +import sys +import os +sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../") + +# TODO: How do I import PathTracker and HomotopyGenerator properly? +from hccd import path_tracker, homotopy_generator +# autopep8: on + + +# class RepetitionCodeHomotopy: +# """Helper type implementing necessary functions for PathTracker. +# +# Repetiton code homotopy: +# G = [[x1], +# [x2], +# [x1]] +# +# F = [[1 - x1**2], +# [1 - x2**2], +# [1 - x1*x2]] +# +# H = (1-t)*G + t*F +# +# Note that +# y := [[x1], +# [x2], +# [t]] +# """ +# @staticmethod +# def evaluate_H(y: np.ndarray) -> np.ndarray: +# """Evaluate H at y.""" +# x1 = y[0] +# x2 = y[1] +# t = y[2] +# +# print(y) +# +# result = np.zeros(shape=3) +# result[0] = -t*x1**2 + x1*(1-t) + t +# result[1] = -t*x2**2 + x2*(1-t) + t +# result[2] = -t*x1*x2 + x1*(1-t) + t +# +# return result +# +# @staticmethod +# def evaluate_DH(y: np.ndarray) -> np.ndarray: +# """Evaluate Jacobian of H at y.""" +# x1 = y[0] +# x2 = y[1] +# t = y[2] +# +# result = np.zeros(shape=(3, 3)) +# result[0, 0] = -2*t*x1 + (1-t) +# result[0, 1] = 0 +# result[0, 2] = -x1**2 - x1 + 1 +# result[1, 0] = 0 +# result[1, 1] = -2*t*x2 + (1-t) +# result[1, 2] = -x2**2 - x2 + 1 +# result[1, 0] = -t*x2 + (1-t) +# result[1, 1] = -t*x1 +# result[1, 2] = -x1*x2 - x1 + 1 +# +# return result + + +def track_path(args): + H = np.array([[1, 1, 1]]) + homotopy = homotopy_generator.HomotopyGenerator(H) + + tracker = path_tracker.PathTracker(homotopy, args.euler_step_size, args.euler_max_tries, + args.newton_max_iter, args.newton_convergence_threshold, args.sigma) + + ys_start, ys_prime, ys_hat_e, ys = [], [], [], [] + + try: + y = np.zeros(3) + for i in range(args.num_iterations): + y_start, y_prime, y_hat_e, y = tracker.transparent_step(y) + ys_start.append(y_start) + ys_prime.append(y_prime) + ys_hat_e.append(y_hat_e) + ys.append(y) + print(f"Iteration {i}: {y}") + except Exception as e: + print(f"Error: {e}") + + ys_start = np.array(ys_start) + ys_prime = np.array(ys_prime) + ys_hat_e = np.array(ys_hat_e) + ys = np.array(ys) + + df = pd.DataFrame({"x1b": ys_start[:, 0], + "x2b": ys_start[:, 1], + "tb": ys_start[:, 2], + "x1p": ys_prime[:, 0], + "x2p": ys_prime[:, 1], + "tp": ys_prime[:, 2], + "x1e": ys_hat_e[:, 0], + "x2e": ys_hat_e[:, 1], + "te": ys_hat_e[:, 2], + "x1n": ys[:, 0], + "x2n": ys[:, 1], + "tn": ys[:, 2] + }) + + if args.output: + df.to_csv(args.output, index=False) + else: + print(df) + + +def main(): + parser = argparse.ArgumentParser( + description='Homotopy continuation path tracker') + + parser.add_argument("--verbose", default=False, action='store_true') + parser.add_argument("--euler-step-size", type=float, + default=0.05, help="Step size for Euler predictor") + parser.add_argument("--euler-max-tries", type=int, default=5, + help="Maximum number of tries for Euler predictor") + parser.add_argument("--newton-max-iter", type=int, default=5, + help="Maximum number of iterations for Newton corrector") + parser.add_argument("--newton-convergence-threshold", type=float, + default=0.01, help="Convergence threshold for Newton corrector") + parser.add_argument("-s", "--sigma", type=int, default=1, + help="Direction in which the path is traced") + parser.add_argument("-o", "--output", type=str, help="Output csv file") + parser.add_argument("-n", "--num-iterations", type=int, default=20, + help="Number of iterations of the example program to run") + + args = parser.parse_args() + + track_path(args) + + +if __name__ == '__main__': + main() diff --git a/python/examples/toy_homotopy.py b/python/examples/toy_homotopy.py new file mode 100644 index 0000000..1e6925b --- /dev/null +++ b/python/examples/toy_homotopy.py @@ -0,0 +1,131 @@ +import argparse +import numpy as np +import pandas as pd + +# autopep8: off +import sys +import os +sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../") + +# TODO: How do I import PathTracker and HomotopyGenerator properly? +from hccd import path_tracker, homotopy_generator +# autopep8: on + + +class ToyHomotopy: + """Helper type implementing necessary functions for PathTracker. + + Toy example homotopy: + G = [[x1], + [x2]] + + F = [[x1 + x2 ], + [x2 + 0.5]] + + H = (1-t)*G + t*F + + Note that + y := [[x1], + [x2], + [t]] + """ + @staticmethod + def evaluate_H(y: np.ndarray) -> np.ndarray: + """Evaluate H at y.""" + x1 = y[0] + x2 = y[1] + t = y[2] + + result = np.zeros(shape=2) + result[0] = x1 + t * x2 + result[1] = x2 + t * 0.5 + + return result + + @staticmethod + def evaluate_DH(y: np.ndarray) -> np.ndarray: + """Evaluate Jacobian of H at y.""" + x1 = y[0] + x2 = y[1] + t = y[2] + + result = np.zeros(shape=(2, 3)) + result[0, 0] = 1 + result[0, 1] = t + result[0, 2] = x2 + result[1, 0] = 0 + result[1, 1] = 1 + result[1, 2] = 0.5 + + return result + + +def track_path(args): + tracker = path_tracker.PathTracker(ToyHomotopy, args.euler_step_size, args.euler_max_tries, + args.newton_max_iter, args.newton_convergence_threshold, args.sigma) + + ys_start, ys_prime, ys_hat_e, ys = [], [], [], [] + + try: + y = np.zeros(3) + for i in range(args.num_iterations): + y_start, y_prime, y_hat_e, y = tracker.transparent_step(y) + ys_start.append(y_start) + ys_prime.append(y_prime) + ys_hat_e.append(y_hat_e) + ys.append(y) + print(f"Iteration {i}: {y}") + except Exception as e: + print(f"Error: {e}") + + ys_start = np.array(ys_start) + ys_prime = np.array(ys_prime) + ys_hat_e = np.array(ys_hat_e) + ys = np.array(ys) + + df = pd.DataFrame({"x1b": ys_start[:, 0], + "x2b": ys_start[:, 1], + "tb": ys_start[:, 2], + "x1p": ys_prime[:, 0], + "x2p": ys_prime[:, 1], + "tp": ys_prime[:, 2], + "x1e": ys_hat_e[:, 0], + "x2e": ys_hat_e[:, 1], + "te": ys_hat_e[:, 2], + "x1n": ys[:, 0], + "x2n": ys[:, 1], + "tn": ys[:, 2] + }) + + if args.output: + df.to_csv(args.output, index=False) + else: + print(df) + + +def main(): + parser = argparse.ArgumentParser( + description='Homotopy continuation path tracker') + + parser.add_argument("--verbose", default=False, action='store_true') + parser.add_argument("--euler-step-size", type=float, + default=0.05, help="Step size for Euler predictor") + parser.add_argument("--euler-max-tries", type=int, default=5, + help="Maximum number of tries for Euler predictor") + parser.add_argument("--newton-max-iter", type=int, default=5, + help="Maximum number of iterations for Newton corrector") + parser.add_argument("--newton-convergence-threshold", type=float, + default=0.01, help="Convergence threshold for Newton corrector") + parser.add_argument("-s", "--sigma", type=int, default=1, + help="Direction in which the path is traced") + parser.add_argument("-o", "--output", type=str, help="Output csv file") + parser.add_argument("-n", "--num-iterations", type=int, default=20, + help="Number of iterations of the example program to run") + + args = parser.parse_args() + + track_path(args) + + +if __name__ == '__main__': + main() diff --git a/python/hccd/__init__.py b/python/hccd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/hccd/__main__.py b/python/hccd/__main__.py new file mode 100644 index 0000000..cd9ac48 --- /dev/null +++ b/python/hccd/__main__.py @@ -0,0 +1,6 @@ +def main(): + pass + + +if __name__ == "__main__": + main() diff --git a/python/hccd/homotopy_generator.py b/python/hccd/homotopy_generator.py new file mode 100644 index 0000000..79f6264 --- /dev/null +++ b/python/hccd/homotopy_generator.py @@ -0,0 +1,117 @@ +import numpy as np +import sympy as sp +from typing import List, Callable + + +class HomotopyGenerator: + """Generates homotopy functions from a binary parity check matrix.""" + + def __init__(self, parity_check_matrix: np.ndarray): + """ + Initialize with a parity check matrix. + + Args: + parity_check_matrix: Binary matrix where rows represent parity checks + and columns represent variables. + """ + self.H_matrix = parity_check_matrix + self.num_checks, self.num_vars = parity_check_matrix.shape + + # Create symbolic variables + self.x_vars = [sp.symbols(f'x{i+1}') for i in range(self.num_vars)] + self.t = sp.symbols('t') + + # Generate G, F, and H + self.G = self._create_G() + self.F = self._create_F() + self.H = self._create_H() + + # Convert to callable functions + self._H_lambda = self._create_H_lambda() + self._DH_lambda = self._create_DH_lambda() + + def _create_G(self) -> List[sp.Expr]: + """Create G polynomial system (the starting system).""" + # For each variable xi, add the polynomial [xi] + G = [] + for var in self.x_vars: + G.append(var) + + return G + + def _create_F(self) -> List[sp.Expr]: + """Create F polynomial system (the target system).""" + F = [] + + # Add 1 - xi^2 for each variable + for var in self.x_vars: + F.append(1 - var**2) + + # Add parity check polynomials: 1 - x1*x2*...*xk for each parity check + for row in self.H_matrix: + # Create product of variables that participate in this check + term = 1 + for i, bit in enumerate(row): + if bit == 1: + term *= self.x_vars[i] + + if term != 1: # Only add if there are variables in this check + F.append(1 - term) + + return F + + def _create_H(self) -> List[sp.Expr]: + """Create the homotopy H = (1-t)*G + t*F.""" + H = [] + + # Make sure G and F have the same length + # Repeat variables from G if needed to match F's length + G_extended = self.G.copy() + while len(G_extended) < len(self.F): + # Cycle through variables to repeat + for i in range(min(self.num_vars, len(self.F) - len(G_extended))): + G_extended.append(self.x_vars[i % self.num_vars]) + + # Create the homotopy + for g, f in zip(G_extended, self.F): + H.append((1 - self.t) * g + self.t * f) + + return H + + def _create_H_lambda(self) -> Callable: + """Create a lambda function to evaluate H.""" + all_vars = self.x_vars + [self.t] + return sp.lambdify(all_vars, self.H, 'numpy') + + def _create_DH_lambda(self) -> Callable: + """Create a lambda function to evaluate the Jacobian of H.""" + all_vars = self.x_vars + [self.t] + jacobian = sp.Matrix([[sp.diff(expr, var) + for var in all_vars] for expr in self.H]) + return sp.lambdify(all_vars, jacobian, 'numpy') + + def evaluate_H(self, y: np.ndarray) -> np.ndarray: + """ + Evaluate H at point y. + + Args: + y: Array of form [x1, x2, ..., xn, t] where xi are the variables + and t is the homotopy parameter. + + Returns: + Array containing H evaluated at y. + """ + return np.array(self._H_lambda(*y)) + + def evaluate_DH(self, y: np.ndarray) -> np.ndarray: + """ + Evaluate the Jacobian of H at point y. + + Args: + y: Array of form [x1, x2, ..., xn, t] where xi are the variables + and t is the homotopy parameter. + + Returns: + Matrix containing the Jacobian of H evaluated at y. + """ + return np.array(self._DH_lambda(*y), dtype=float) diff --git a/python/hccd/path_tracker.py b/python/hccd/path_tracker.py new file mode 100644 index 0000000..957e386 --- /dev/null +++ b/python/hccd/path_tracker.py @@ -0,0 +1,84 @@ +import numpy as np +import typing +import scipy + + +def _sign(val): + return -1 * (val < 0) + 1 * (val >= 0) + + +class PathTracker: + """ + Path trakcer for the homotopy continuation method. Uses a + predictor-corrector scheme to trace a path defined by a homotopy. + + References: + [1] T. Chen and T.-Y. Li, “Homotopy continuation method for solving + systems of nonlinear and polynomial equations,” Communications in + Information and Systems, vol. 15, no. 2, pp. 119–307, 2015 + """ + + def __init__(self, Homotopy, euler_step_size=0.05, euler_max_tries=10, newton_max_iter=5, + newton_convergence_threshold=0.001, sigma=1): + self.Homotopy = Homotopy + self._euler_step_size = euler_step_size + self._euler_max_tries = euler_max_tries + self._newton_max_iter = newton_max_iter + self._newton_convergence_threshold = newton_convergence_threshold + self._sigma = sigma + + def step(self, y): + """Perform one predictor-corrector step.""" + return self.transparent_step(y)[0] + + def transparent_step(self, y) -> typing.Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray,]: + """Perform one predictor-corrector step, returning intermediate results.""" + for i in range(self._euler_max_tries): + step_size = self._euler_step_size / (1 << i) + + y_hat, y_prime = self._perform_euler_predictor_step(y, step_size) + y_hat_n = self._perform_newtown_corrector_step(y_hat) + + return y, y_prime, y_hat, y_hat_n + + raise RuntimeError("Newton corrector did not converge") + + def _perform_euler_predictor_step(self, y, step_size) -> typing.Tuple[np.ndarray, np.ndarray]: + # Obtain y_prime + + DH = self.Homotopy.evaluate_DH(y) + ns = scipy.linalg.null_space(DH) + + y_prime = ns[:, 0] * self._sigma + + # Q, R = np.linalg.qr(np.transpose(DH), mode="complete") + # y_prime = Q[:, 2] + # + # if _sign(np.linalg.det(Q)*np.linalg.det(R[:2, :])) != _sign(self._sigma): + # y_prime = -y_prime + + # Perform prediction + + y_hat = y + step_size*y_prime + + return y_hat, y_prime + + def _perform_newtown_corrector_step(self, y) -> np.ndarray: + prev_y = y + + for _ in range(self._newton_max_iter): + # Perform correction + + DH = self.Homotopy.evaluate_DH(y) + DH_pinv = np.linalg.pinv(DH) + + y = y - DH_pinv @ self.Homotopy.evaluate_H(y) + + # Check stopping criterion + + if np.linalg.norm(y - prev_y) < self._newton_convergence_threshold: + return y + + prev_y = y + + raise RuntimeError("Newton corrector did not converge") diff --git a/scripts/plot_solution_curve.py b/scripts/plot_solution_curve.py new file mode 100644 index 0000000..2e65bf0 --- /dev/null +++ b/scripts/plot_solution_curve.py @@ -0,0 +1,106 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import argparse + + +def plot_2d(df: pd.DataFrame): + df = pd.read_csv("temp.csv") + + fig = plt.figure() + plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95) + ax = fig.add_subplot() + + # Analytically computed solution + ts = df['tb'] + x2s = -ts / 2 + x1s = -ts * x2s + + ax.scatter(x1s, x2s) + + ax.set_xlabel("x1") + ax.set_ylabel("x2") + # ax.set_zlabel("t") + + df = df.reset_index() + + for _, row in df.iterrows(): + x1b = row['x1b'] + x2b = row['x2b'] + # x1p = row['x1p'] + # x2p = row['x2p'] + x1e = row['x1e'] + x2e = row['x2e'] + x1n = row['x1n'] + x2n = row['x2n'] + + ax.plot([x1b, x1e], [x2b, x2e], 'b') + ax.plot([x1e, x1n], [x2e, x2n], 'r') + + plt.show() + + +def plot_3d(df: pd.DataFrame): + df = pd.read_csv("temp.csv") + + fig = plt.figure() + plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95) + ax = fig.add_subplot(projection='3d') + + # Analytically computed solution + ts = df['tb'] + x2s = -ts / 2 + x1s = -ts * x2s + + ax.scatter(x1s, x2s) + ax.scatter(x1s, x2s, ts) + + ax.set_xlabel("x1") + ax.set_ylabel("x2") + ax.set_zlabel("t") + + df = df.reset_index() + + for _, row in df.iterrows(): + x1b = row['x1b'] + x2b = row['x2b'] + tb = row['tb'] + # x1p = row['x1p'] + # x2p = row['x2p'] + # tp = row['tp'] + x1e = row['x1e'] + x2e = row['x2e'] + te = row['te'] + x1n = row['x1n'] + x2n = row['x2n'] + tn = row['tn'] + + ax.plot([x1b, x1e], [x2b, x2e], 'b', zs=[tb, te]) + ax.plot([x1e, x1n], [x2e, x2n], 'r', zs=[te, tn]) + + plt.show() + + +def main(): + # Parse command line arguments + + parser = argparse.ArgumentParser() + + parser.add_argument("--input", "-i", type=str, required=True, + default="temp.csv", help="Filename of the csv file") + parser.add_argument("--projection", "-p", type=str, required=False, + default="2d", help="2d or 3d plot") + + args = parser.parse_args() + + # Plot data + df = pd.read_csv(args.input) + + if args.projection == "2d": + plot_2d(df) + elif args.projection == "3d": + plot_3d(df) + + +if __name__ == "__main__": + main()