diff --git a/src/optimize.jl b/src/optimize.jl index 10627778..3648a7ce 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -31,9 +31,8 @@ result = optimize(problem; method=Krotov) If `check` is true (default), the `initial_state` and `generator` of each trajectory is checked with [`check_state`](@ref) and [`check_generator`](@ref). -Any other keyword argument temporarily overrides the corresponding keyword -argument in [`problem`](@ref ControlProblem). These arguments are available to -the optimizer, see each optimization package's documentation for details. +Additional keyword arguments can be passed to these two `check` routines by +passing a named-tuple as `check_state_kwargs` and `check_generator_kwargs`. The `callback` can be given as a function to be called after each iteration in order to analyze the progress of the optimization or to modify the state of @@ -54,7 +53,8 @@ the method-specific [`make_print_iters`](@ref) to print the progress of the optimization after each iteration. This automatic callback runs after any manually given `callback`. -All remaining keyword argument are method-specific. +All remaining keyword argument are method-specific and temporarily overrides +the corresponding keyword argument in [`problem`](@ref ControlProblem). To obtain the documentation for which options a particular method uses, run, e.g., @@ -75,10 +75,8 @@ function optimize( check = get(problem.kwargs, :check, true), print_iters = get(problem.kwargs, :print_iters, true), callback = get(problem.kwargs, :callback, nothing), - for_expval = true, # undocumented - for_pwc = true, # undocumented - for_time_continuous = false, # undocumented - for_parameterization = false, # undocumented + check_state_kwargs = get(problem.kwargs, :check_state_kwargs, (;)), + check_generator_kwargs = get(problem.kwargs, :check_generator_kwargs, (;)), kwargs... ) @@ -113,20 +111,16 @@ function optimize( ) if check - # TODO: checks will have to be method-dependent, and then we may not - # need all the `for_...` keyword arguments + # TODO: checks maybe should be method-dependent for (i, traj) in enumerate(problem.trajectories) - if !check_state(traj.initial_state) + if !check_state(traj.initial_state; check_state_kwargs...) error("The `initial_state` of trajectory $i is not valid") end if !check_generator( traj.generator; + check_generator_kwargs..., state = traj.initial_state, tlist = problem.tlist, - for_expval, - for_pwc, - for_time_continuous, - for_parameterization, ) error("The `generator` of trajectory $i is not valid") end diff --git a/test/test_optimize_or_load.jl b/test/test_optimize_or_load.jl index bd097eca..e28fbccb 100644 --- a/test/test_optimize_or_load.jl +++ b/test/test_optimize_or_load.jl @@ -5,6 +5,44 @@ using QuantumControl: @optimize_or_load, load_optimization, save_optimization, o using QuantumControlTestUtils.DummyOptimization: dummy_control_problem, DummyOptimizationResult + +@testset "check_state_kwargs and check_generator_kwargs" begin + + problem = dummy_control_problem() + outdir = mktempdir() + outfile = joinpath(outdir, "optimization_check_kwargs.jld2") + + # Smoke test: extra kwargs are accepted + captured = IOCapture.capture(passthrough = false) do + @optimize_or_load( + outfile, + problem; + method = :dummymethod, + force = true, + check_state_kwargs = (; atol = 1e-10), + check_generator_kwargs = (; atol = 1e-10), + ) + end + @test captured.value isa DummyOptimizationResult + @test captured.value.converged + + # Test with teeth: for_time_continuous=true fails check_generator for the + # pwc dummy controls (array-based controls cannot be evaluated at a + # continuous time t) + captured2 = IOCapture.capture(rethrow = Union{}, passthrough = false) do + @optimize_or_load( + outfile, + problem; + method = :dummymethod, + force = true, + check_generator_kwargs = (; for_time_continuous = true), + ) + end + @test captured2.value isa Exception + +end + + @testset "metadata" begin problem = dummy_control_problem()