diff --git a/examples/cli.hpp b/examples/cli.hpp new file mode 100644 index 0000000..f298f5b --- /dev/null +++ b/examples/cli.hpp @@ -0,0 +1,264 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace examples +{ +namespace cli +{ + +inline std::string +normalize_option( std::string option ) +{ + if( option.rfind( "--", 0 ) != 0 ) + return option; + + const auto eq_pos = option.find( '=' ); + const auto end = eq_pos == std::string::npos ? option.size() : eq_pos; + std::replace( option.begin() + 2, option.begin() + static_cast( end ), '_', '-' ); + return option; +} + +inline bool +is_positional( std::string_view arg ) +{ + return !arg.empty() && arg.front() != '-'; +} + +inline int +parse_int( const std::string& label, const std::string& value ) +{ + int result = 0; + const char* begin = value.data(); + const char* end = begin + value.size(); + const auto [ptr, ec] = std::from_chars( begin, end, result ); + if( ec != std::errc() || ptr != end ) + throw std::invalid_argument( "Invalid value for " + label + ": '" + value + "'" ); + return result; +} + +class ArgParser +{ +public: + ArgParser( int argc, char** argv ) : argc_( argc ), argv_( argv ) {} + + bool + empty() const + { + return index_ >= argc_; + } + + std::string_view + peek() const + { + if( empty() ) + return std::string_view{}; + return argv_[index_]; + } + + std::string + take() + { + if( empty() ) + throw std::out_of_range( "No more arguments to consume" ); + return std::string( argv_[index_++] ); + } + + bool + consume_flag( std::string_view long_name, std::string_view short_name = {} ) + { + if( empty() ) + return false; + + std::string current = normalize_option( std::string( peek() ) ); + if( current == long_name || ( !short_name.empty() && current == short_name ) ) + { + ++index_; + return true; + } + return false; + } + + bool + consume_option( std::string_view long_name, std::string& value ) + { + if( empty() ) + return false; + + std::string current = normalize_option( std::string( peek() ) ); + const std::string prefix = std::string( long_name ) + "="; + if( current == long_name ) + { + ++index_; + if( empty() ) + throw std::invalid_argument( "Missing value for option '" + std::string( long_name ) + "'" ); + value = take(); + return true; + } + if( current.rfind( prefix, 0 ) == 0 ) + { + value = current.substr( prefix.size() ); + ++index_; + return true; + } + return false; + } + +private: + int argc_ = 0; + char** argv_ = nullptr; + int index_ = 1; +}; + +} // namespace cli +} // namespace examples + +namespace examples +{ +namespace cli +{ + +struct SolverOptions +{ + bool show_help = false; + std::string solver = "ilqr"; +}; + +inline SolverOptions +parse_solver_options( int argc, char** argv, std::string default_solver = "ilqr" ) +{ + SolverOptions options; + options.solver = std::move( default_solver ); + ArgParser args( argc, argv ); + + while( !args.empty() ) + { + const std::string raw_arg = std::string( args.peek() ); + if( args.consume_flag( "--help", "-h" ) ) + { + options.show_help = true; + continue; + } + + std::string value; + if( args.consume_option( "--solver", value ) ) + { + options.solver = value; + continue; + } + + throw std::invalid_argument( "Unknown argument '" + raw_arg + "'" ); + } + + return options; +} + +struct MultiAgentOptions +{ + bool show_help = false; + int agents = 10; + int max_outer = 10; + std::string solver = "ilqr"; + std::string strategy = "centralized"; +}; + +inline MultiAgentOptions +parse_multi_agent_options( int argc, char** argv, MultiAgentOptions defaults = {} ) +{ + MultiAgentOptions options = std::move( defaults ); + ArgParser args( argc, argv ); + bool positional_agents = false; + + while( !args.empty() ) + { + const std::string raw_arg = std::string( args.peek() ); + if( args.consume_flag( "--help", "-h" ) ) + { + options.show_help = true; + continue; + } + + std::string value; + if( args.consume_option( "--agents", value ) ) + { + options.agents = parse_int( "--agents", value ); + continue; + } + if( args.consume_option( "--solver", value ) ) + { + options.solver = value; + continue; + } + if( args.consume_option( "--strategy", value ) ) + { + options.strategy = value; + continue; + } + if( args.consume_option( "--max-outer", value ) ) + { + options.max_outer = parse_int( "--max-outer", value ); + continue; + } + + if( is_positional( raw_arg ) && !positional_agents ) + { + args.take(); + options.agents = parse_int( "agents", raw_arg ); + positional_agents = true; + continue; + } + + throw std::invalid_argument( "Unknown argument '" + raw_arg + "'" ); + } + + return options; +} + +struct RocketOptions +{ + bool show_help = false; + bool dump_traces = false; + std::string solver = "osqp"; +}; + +inline RocketOptions +parse_rocket_options( int argc, char** argv, RocketOptions defaults = {} ) +{ + RocketOptions options = std::move( defaults ); + ArgParser args( argc, argv ); + + while( !args.empty() ) + { + const std::string raw_arg = std::string( args.peek() ); + if( args.consume_flag( "--help", "-h" ) ) + { + options.show_help = true; + continue; + } + if( args.consume_flag( "--dump" ) ) + { + options.dump_traces = true; + continue; + } + + std::string value; + if( args.consume_option( "--solver", value ) ) + { + options.solver = value; + continue; + } + + throw std::invalid_argument( "Unknown argument '" + raw_arg + "'" ); + } + + return options; +} + +} // namespace cli +} // namespace examples + diff --git a/examples/multi_agent_lqr.cpp b/examples/multi_agent_lqr.cpp index 0cbad8e..3eb251f 100644 --- a/examples/multi_agent_lqr.cpp +++ b/examples/multi_agent_lqr.cpp @@ -1,7 +1,5 @@ #include -#include #include -#include #include #include #include @@ -16,6 +14,7 @@ #include "multi_agent_solver/strategies/strategy.hpp" #include "multi_agent_solver/types.hpp" +#include "cli.hpp" #include "example_utils.hpp" /*──────────────── create simple LQR OCP (unchanged) ───────────────*/ @@ -76,97 +75,11 @@ create_linear_lqr_ocp( int n_x, int n_u, double dt, int T ) return ocp; } -struct Options -{ - bool show_help = false; - int agents = 10; - int max_outer = 10; - std::string solver = "ilqr"; - std::string strategy = "centralized"; -}; +using Options = examples::cli::MultiAgentOptions; namespace { -int -parse_int( const std::string& label, const std::string& value ) -{ - int result = 0; - const char* begin = value.data(); - const char* end = begin + value.size(); - const auto [ptr, ec] = std::from_chars( begin, end, result ); - if( ec != std::errc() || ptr != end ) - throw std::invalid_argument( "Invalid value for " + label + ": '" + value + "'" ); - return result; -} - -Options -parse_options( int argc, char** argv ) -{ - Options options; - bool positional_agents = false; - for( int i = 1; i < argc; ++i ) - { - std::string arg = argv[i]; - if( arg.rfind( "--", 0 ) == 0 ) - { - const auto eq_pos = arg.find( '=' ); - const auto end = eq_pos == std::string::npos ? arg.size() : eq_pos; - std::replace( arg.begin() + 2, arg.begin() + static_cast( end ), '_', '-' ); - } - auto match_with_value = [&]( const std::string& name, std::string& out ) { - const std::string prefix = name + "="; - if( arg == name ) - { - if( i + 1 >= argc ) - throw std::invalid_argument( "Missing value for option '" + name + "'" ); - out = argv[++i]; - return true; - } - if( arg.rfind( prefix, 0 ) == 0 ) - { - out = arg.substr( prefix.size() ); - return true; - } - return false; - }; - - if( arg == "--help" || arg == "-h" ) - { - options.show_help = true; - continue; - } - - std::string value; - if( match_with_value( "--agents", value ) ) - { - options.agents = parse_int( "--agents", value ); - } - else if( match_with_value( "--solver", value ) ) - { - options.solver = value; - } - else if( match_with_value( "--strategy", value ) ) - { - options.strategy = value; - } - else if( match_with_value( "--max-outer", value ) ) - { - options.max_outer = parse_int( "--max-outer", value ); - } - else if( !arg.empty() && arg.front() != '-' && !positional_agents ) - { - options.agents = parse_int( "agents", arg ); - positional_agents = true; - } - else - { - throw std::invalid_argument( "Unknown argument '" + arg + "'" ); - } - } - return options; -} - void print_usage() { @@ -185,7 +98,7 @@ main( int argc, char** argv ) using namespace mas; try { - const Options options = parse_options( argc, argv ); + const Options options = examples::cli::parse_multi_agent_options( argc, argv ); if( options.show_help ) { print_usage(); diff --git a/examples/multi_agent_single_track.cpp b/examples/multi_agent_single_track.cpp index 3ee9159..32713d8 100644 --- a/examples/multi_agent_single_track.cpp +++ b/examples/multi_agent_single_track.cpp @@ -1,16 +1,15 @@ #include #include -#include #include #include #include #include #include #include -#include #include +#include "cli.hpp" #include "example_utils.hpp" #include "models/single_track_model.hpp" #include "multi_agent_solver/agent.hpp" @@ -53,97 +52,11 @@ create_single_track_circular_ocp( double initial_theta, double track_radius, dou return problem; } -struct Options -{ - bool show_help = false; - int agents = 10; - int max_outer = 10; - std::string solver = "ilqr"; - std::string strategy = "centralized"; -}; +using Options = examples::cli::MultiAgentOptions; namespace { -int -parse_int( const std::string& label, const std::string& value ) -{ - int result = 0; - const char* begin = value.data(); - const char* end = begin + value.size(); - const auto [ptr, ec] = std::from_chars( begin, end, result ); - if( ec != std::errc() || ptr != end ) - throw std::invalid_argument( "Invalid value for " + label + ": '" + value + "'" ); - return result; -} - -Options -parse_options( int argc, char** argv ) -{ - Options options; - bool positional_agents = false; - for( int i = 1; i < argc; ++i ) - { - std::string arg = argv[i]; - if( arg.rfind( "--", 0 ) == 0 ) - { - const auto eq_pos = arg.find( '=' ); - const auto end = eq_pos == std::string::npos ? arg.size() : eq_pos; - std::replace( arg.begin() + 2, arg.begin() + static_cast( end ), '_', '-' ); - } - auto match_with_value = [&]( const std::string& name, std::string& out ) { - const std::string prefix = name + "="; - if( arg == name ) - { - if( i + 1 >= argc ) - throw std::invalid_argument( "Missing value for option '" + name + "'" ); - out = argv[++i]; - return true; - } - if( arg.rfind( prefix, 0 ) == 0 ) - { - out = arg.substr( prefix.size() ); - return true; - } - return false; - }; - - if( arg == "--help" || arg == "-h" ) - { - options.show_help = true; - continue; - } - - std::string value; - if( match_with_value( "--agents", value ) ) - { - options.agents = parse_int( "--agents", value ); - } - else if( match_with_value( "--solver", value ) ) - { - options.solver = value; - } - else if( match_with_value( "--strategy", value ) ) - { - options.strategy = value; - } - else if( match_with_value( "--max-outer", value ) ) - { - options.max_outer = parse_int( "--max-outer", value ); - } - else if( !arg.empty() && arg.front() != '-' && !positional_agents ) - { - options.agents = parse_int( "agents", arg ); - positional_agents = true; - } - else - { - throw std::invalid_argument( "Unknown argument '" + arg + "'" ); - } - } - return options; -} - void print_usage() { @@ -161,7 +74,7 @@ main( int argc, char** argv ) using namespace mas; try { - const Options options = parse_options( argc, argv ); + const Options options = examples::cli::parse_multi_agent_options( argc, argv ); if( options.show_help ) { print_usage(); diff --git a/examples/pendulum_swing_up.cpp b/examples/pendulum_swing_up.cpp index 08e9a98..aebb00e 100644 --- a/examples/pendulum_swing_up.cpp +++ b/examples/pendulum_swing_up.cpp @@ -6,6 +6,7 @@ #include #include +#include "cli.hpp" #include "example_utils.hpp" #include "models/pendulum_model.hpp" #include "multi_agent_solver/ocp.hpp" @@ -88,58 +89,11 @@ create_pendulum_swingup_ocp() return problem; } -struct Options -{ - bool show_help = false; - std::string solver = "ilqr"; -}; +using Options = examples::cli::SolverOptions; namespace { -Options -parse_options( int argc, char** argv ) -{ - Options options; - for( int i = 1; i < argc; ++i ) - { - std::string arg = argv[i]; - auto match_with_value = [&]( const std::string& name, std::string& out ) { - const std::string prefix = name + "="; - if( arg == name ) - { - if( i + 1 >= argc ) - throw std::invalid_argument( "Missing value for option '" + name + "'" ); - out = argv[++i]; - return true; - } - if( arg.rfind( prefix, 0 ) == 0 ) - { - out = arg.substr( prefix.size() ); - return true; - } - return false; - }; - - if( arg == "--help" || arg == "-h" ) - { - options.show_help = true; - continue; - } - - std::string value; - if( match_with_value( "--solver", value ) ) - { - options.solver = value; - } - else - { - throw std::invalid_argument( "Unknown argument '" + arg + "'" ); - } - } - return options; -} - void print_usage() { @@ -156,7 +110,7 @@ main( int argc, char** argv ) using namespace mas; try { - const Options options = parse_options( argc, argv ); + const Options options = examples::cli::parse_solver_options( argc, argv ); if( options.show_help ) { print_usage(); diff --git a/examples/rocket_max_altitude.cpp b/examples/rocket_max_altitude.cpp index d176d05..38a7a94 100644 --- a/examples/rocket_max_altitude.cpp +++ b/examples/rocket_max_altitude.cpp @@ -2,6 +2,7 @@ #include #include +#include "cli.hpp" #include "example_utils.hpp" #include "models/rocket_model.hpp" #include "multi_agent_solver/ocp.hpp" @@ -113,51 +114,6 @@ create_max_altitude_rocket_ocp() return problem; } -struct Options -{ - bool show_help = false; - bool dump_traces = false; - std::string solver = "osqp"; -}; - -Options -parse_options( int argc, char** argv ) -{ - Options options; - for( int i = 1; i < argc; ++i ) - { - const std::string arg = argv[i]; - if( arg == "--help" || arg == "-h" ) - { - options.show_help = true; - continue; - } - if( arg == "--dump" ) - { - options.dump_traces = true; - continue; - } - - const std::string solver_prefix = "--solver="; - if( arg.rfind( solver_prefix, 0 ) == 0 ) - { - options.solver = arg.substr( solver_prefix.size() ); - continue; - } - - if( arg == "--solver" ) - { - if( i + 1 >= argc ) - throw std::invalid_argument( "Missing value for --solver" ); - options.solver = argv[++i]; - continue; - } - - throw std::invalid_argument( "Unknown argument '" + arg + "'" ); - } - return options; -} - void print_usage() { @@ -175,7 +131,7 @@ main( int argc, char** argv ) try { - const Options options = parse_options( argc, argv ); + const examples::cli::RocketOptions options = examples::cli::parse_rocket_options( argc, argv ); if( options.show_help ) { print_usage(); diff --git a/examples/single_track_ocp.cpp b/examples/single_track_ocp.cpp index 43e3168..3b924f4 100644 --- a/examples/single_track_ocp.cpp +++ b/examples/single_track_ocp.cpp @@ -4,6 +4,7 @@ #include #include +#include "cli.hpp" #include "example_utils.hpp" #include "models/single_track_model.hpp" #include "multi_agent_solver/ocp.hpp" @@ -114,58 +115,11 @@ create_single_track_lane_following_ocp() return problem; } -struct Options -{ - bool show_help = false; - std::string solver = "ilqr"; -}; +using Options = examples::cli::SolverOptions; namespace { -Options -parse_options( int argc, char** argv ) -{ - Options options; - for( int i = 1; i < argc; ++i ) - { - std::string arg = argv[i]; - auto match_with_value = [&]( const std::string& name, std::string& out ) { - const std::string prefix = name + "="; - if( arg == name ) - { - if( i + 1 >= argc ) - throw std::invalid_argument( "Missing value for option '" + name + "'" ); - out = argv[++i]; - return true; - } - if( arg.rfind( prefix, 0 ) == 0 ) - { - out = arg.substr( prefix.size() ); - return true; - } - return false; - }; - - if( arg == "--help" || arg == "-h" ) - { - options.show_help = true; - continue; - } - - std::string value; - if( match_with_value( "--solver", value ) ) - { - options.solver = value; - } - else - { - throw std::invalid_argument( "Unknown argument '" + arg + "'" ); - } - } - return options; -} - void print_usage() { @@ -182,7 +136,7 @@ main( int argc, char** argv ) using namespace mas; try { - const Options options = parse_options( argc, argv ); + const Options options = examples::cli::parse_solver_options( argc, argv ); if( options.show_help ) { print_usage();