diff --git a/lib/gen_lsp.ex b/lib/gen_lsp.ex index 094c73c..c5d232b 100644 --- a/lib/gen_lsp.ex +++ b/lib/gen_lsp.ex @@ -83,6 +83,8 @@ defmodule GenLSP do end end + @optional_callbacks handle_continue: 2 + require Logger @doc """ @@ -101,7 +103,8 @@ defmodule GenLSP do end ``` """ - @callback init(lsp :: GenLSP.LSP.t(), init_arg :: term()) :: {:ok, GenLSP.LSP.t()} + @callback init(lsp :: GenLSP.LSP.t(), init_arg :: term()) :: + {:ok, GenLSP.LSP.t()} | {:ok, GenLSP.LSP.t(), {:continue, term()}} @doc """ The callback responsible for handling requests from the client. @@ -127,7 +130,10 @@ defmodule GenLSP do ``` """ @callback handle_request(request :: term(), state) :: - {:reply, reply :: term(), state} | {:noreply, state} + {:reply, reply :: term(), state} + | {:reply, reply :: term(), state, {:continue, term()}} + | {:noreply, state} + | {:noreply, state, {:continue, term()}} when state: GenLSP.LSP.t() @doc """ The callback responsible for handling notifications from the client. @@ -145,7 +151,8 @@ defmodule GenLSP do end ``` """ - @callback handle_notification(notification :: term(), state) :: {:noreply, state} + @callback handle_notification(notification :: term(), state) :: + {:noreply, state} | {:noreply, state, {:continue, term()}} when state: GenLSP.LSP.t() @doc """ The callback responsible for handling normal messages. @@ -163,7 +170,20 @@ defmodule GenLSP do end ``` """ - @callback handle_info(message :: any(), state) :: {:noreply, state} when state: GenLSP.LSP.t() + @callback handle_info(message :: any(), state) :: + {:noreply, state} | {:noreply, state, {:continue, term()}} + when state: GenLSP.LSP.t() + + @doc """ + Invoked to handle continue instructions, workalike to `GenServer.handle_continue/2`. + + The continue callback is guaranteed to execute after the initiating reply was written to + the wire. This ensures the client receives the response before any side effects from + the continue callback. + """ + @callback handle_continue(continue_arg :: term(), state) :: + {:noreply, state} | {:noreply, state, {:continue, term()}} + when state: GenLSP.LSP.t() @options_schema NimbleOptions.new!( buffer: [ @@ -217,16 +237,21 @@ defmodule GenLSP do tasks: Map.new() } - case module.init(lsp, init_args) do - {:ok, %LSP{} = lsp} -> - deb = :sys.debug_options([]) - if opts[:name], do: Process.register(self(), opts[:name]) - :proc_lib.init_ack(parent, {:ok, me}) + {lsp, continue_arg} = + case module.init(lsp, init_args) do + {:ok, %LSP{} = lsp} -> {lsp, nil} + {:ok, %LSP{} = lsp, {:continue, continue_arg}} -> {lsp, continue_arg} + end - GenLSP.Buffer.listen(buffer, me) + deb = :sys.debug_options([]) + if opts[:name], do: Process.register(self(), opts[:name]) + :proc_lib.init_ack(parent, {:ok, me}) - loop(lsp, parent, deb) - end + GenLSP.Buffer.listen(buffer, me) + + lsp = if continue_arg, do: execute_continue(lsp, continue_arg), else: lsp + + loop(lsp, parent, deb) end @doc """ @@ -369,70 +394,85 @@ defmodule GenLSP do {lsp.mod.handle_request(req, lsp), %{}} end) - case result do - {:reply, reply, %LSP{} = lsp} -> - response_key = - case reply do - %GenLSP.ErrorResponse{} -> "error" - _ -> "result" - end - - # if result is valid, continue, if not, we return an internal error - {response_key, response} = - case Schematic.dump(req.__struct__.result(), reply) do - {:ok, output} -> - {response_key, output} + # Normalize result to extract continue_arg + {result, continue_arg} = + case result do + {:reply, reply, %LSP{} = lsp} -> + {{:reply, reply, lsp}, nil} - {:error, errors} -> - exception = InvalidResponse.exception({req.method, reply, errors}) + {:reply, reply, %LSP{} = lsp, {:continue, arg}} -> + {{:reply, reply, lsp}, arg} - Logger.error(Exception.format(:error, exception)) + {:noreply, %LSP{} = lsp} -> + {{:noreply, lsp}, nil} - {:ok, output} = - Schematic.dump( - GenLSP.ErrorResponse.schema(), - %GenLSP.ErrorResponse{ - code: GenLSP.Enumerations.ErrorCodes.internal_error(), - message: exception.message - } - ) + {:noreply, %LSP{} = lsp, {:continue, arg}} -> + {{:noreply, lsp}, arg} + end - {"error", output} + lsp = + case result do + {:reply, reply, %LSP{} = lsp} -> + response_key = + case reply do + %GenLSP.ErrorResponse{} -> "error" + _ -> "result" + end + + # if result is valid, continue, if not, we return an internal error + {response_key, response} = + case Schematic.dump(req.__struct__.result(), reply) do + {:ok, output} -> + {response_key, output} + + {:error, errors} -> + exception = InvalidResponse.exception({req.method, reply, errors}) + + Logger.error(Exception.format(:error, exception)) + + {:ok, output} = + Schematic.dump( + GenLSP.ErrorResponse.schema(), + %GenLSP.ErrorResponse{ + code: GenLSP.Enumerations.ErrorCodes.internal_error(), + message: exception.message + } + ) + + {"error", output} + end + + packet = %{ + "jsonrpc" => "2.0", + "id" => id, + response_key => response + } + + if continue_arg do + GenLSP.Buffer.outgoing_flush(lsp.buffer, packet) + else + GenLSP.Buffer.outgoing(lsp.buffer, packet) end - packet = %{ - "jsonrpc" => "2.0", - "id" => id, - response_key => response - } - - GenLSP.Buffer.outgoing(lsp.buffer, packet) - - duration = System.system_time(:microsecond) - start + lsp - Logger.debug( - "handled request client -> server #{req.method} in #{format_time(duration)}", - id: req.id, - method: req.method - ) + {:noreply, %LSP{} = lsp} -> + lsp + end - :telemetry.execute([:gen_lsp, :request, :client, :stop], %{ - duration: duration - }) + duration = System.system_time(:microsecond) - start - {:noreply, _lsp} -> - duration = System.system_time(:microsecond) - start + Logger.debug( + "handled request client -> server #{req.method} in #{format_time(duration)}", + id: req.id, + method: req.method + ) - Logger.debug( - "handled request client -> server #{req.method} in #{format_time(duration)}", - id: req.id, - method: req.method - ) + :telemetry.execute([:gen_lsp, :request, :client, :stop], %{ + duration: duration + }) - :telemetry.execute([:gen_lsp, :request, :client, :stop], %{ - duration: duration - }) - end + if continue_arg, do: execute_continue(lsp, continue_arg) {:error, errors} -> # the payload is not parseable at all, other than being valid JSON and having @@ -520,19 +560,24 @@ defmodule GenLSP do end ) - case result do - {:noreply, %LSP{}} -> - duration = System.system_time(:microsecond) - start + {lsp, continue_arg} = + case result do + {:noreply, %LSP{} = lsp} -> {lsp, nil} + {:noreply, %LSP{} = lsp, {:continue, arg}} -> {lsp, arg} + end - Logger.debug( - "handled notification client -> server #{note.method} in #{format_time(duration)}", - method: note.method - ) + duration = System.system_time(:microsecond) - start - :telemetry.execute([:gen_lsp, :notification, :client, :stop], %{ - duration: duration - }) - end + Logger.debug( + "handled notification client -> server #{note.method} in #{format_time(duration)}", + method: note.method + ) + + :telemetry.execute([:gen_lsp, :notification, :client, :stop], %{ + duration: duration + }) + + if continue_arg, do: execute_continue(lsp, continue_arg) {:error, errors} -> # the payload is not parseable at all, other than being valid JSON @@ -564,11 +609,16 @@ defmodule GenLSP do {lsp.mod.handle_info(message, lsp), %{}} end) - case result do - {:noreply, %LSP{} = _lsp} -> - duration = System.system_time(:microsecond) - start - :telemetry.execute([:gen_lsp, :info, :stop], %{duration: duration}) - end + {lsp, continue_arg} = + case result do + {:noreply, %LSP{} = lsp} -> {lsp, nil} + {:noreply, %LSP{} = lsp, {:continue, arg}} -> {lsp, arg} + end + + duration = System.system_time(:microsecond) - start + :telemetry.execute([:gen_lsp, :info, :stop], %{duration: duration}) + + if continue_arg, do: execute_continue(lsp, continue_arg) end ) @@ -603,6 +653,21 @@ defmodule GenLSP do end) end + defp execute_continue(lsp, continue_arg) do + result = + :telemetry.span([:gen_lsp, :handle_continue], %{continue: continue_arg}, fn -> + {lsp.mod.handle_continue(continue_arg, lsp), %{}} + end) + + case result do + {:noreply, %LSP{} = lsp} -> + lsp + + {:noreply, %LSP{} = lsp, {:continue, next_continue}} -> + execute_continue(lsp, next_continue) + end + end + defp dump!(schematic, structure) do {:ok, output} = Schematic.dump(schematic, structure) output diff --git a/lib/gen_lsp/buffer.ex b/lib/gen_lsp/buffer.ex index d22c077..d7dff40 100644 --- a/lib/gen_lsp/buffer.ex +++ b/lib/gen_lsp/buffer.ex @@ -52,6 +52,11 @@ defmodule GenLSP.Buffer do GenServer.call(server, {:outgoing_sync, packet}, timeout) end + @doc false + def outgoing_flush(server, packet) do + GenServer.call(server, {:outgoing_flush, packet}) + end + @doc false def comm_state(server) do GenServer.call(server, :comm_state) @@ -79,6 +84,15 @@ defmodule GenLSP.Buffer do {:noreply, %{state | awaiting_response: Map.put(state.awaiting_response, id, from)}} end + def handle_call({:outgoing_flush, packet}, _from, state) do + :telemetry.span([:gen_lsp, :buffer, :outgoing], %{kind: :flush}, fn -> + :ok = state.comm.write(Jason.encode!(packet), state.comm_data) + {:ok, %{}} + end) + + {:reply, :ok, state} + end + @doc false def handle_cast({:incoming, packet}, %{lsp: lsp} = state) do state = diff --git a/mix.lock b/mix.lock index e460ab2..e2c4612 100644 --- a/mix.lock +++ b/mix.lock @@ -1,14 +1,14 @@ %{ "dialyxir": {:hex, :dialyxir, "1.3.0", "fd1672f0922b7648ff9ce7b1b26fcf0ef56dda964a459892ad15f6b4410b5284", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "00b2a4bcd6aa8db9dcb0b38c1225b7277dca9bc370b6438715667071a304696f"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.26", "f4291134583f373c7d8755566122908eb9662df4c4b63caa66a0eabe06569b0a", [:mix], [], "hexpm", "48d460899f8a0c52c5470676611c01f64f3337bad0b26ddab43648428d94aabc"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.44", "f20830dd6b5c77afe2b063777ddbbff09f9759396500cdbe7523efd58d7a339c", [:mix], [], "hexpm", "4778ac752b4701a5599215f7030989c989ffdc4f6df457c5f36938cc2d2a2750"}, "erlex": {:hex, :erlex, "0.2.6", "c7987d15e899c7a2f34f5420d2a2ea0d659682c06ac607572df55a43753aa12e", [:mix], [], "hexpm", "2ed2e25711feb44d52b17d2780eabf998452f6efda104877a3881c2f8c0c0c75"}, - "ex_doc": {:hex, :ex_doc, "0.28.4", "001a0ea6beac2f810f1abc3dbf4b123e9593eaa5f00dd13ded024eae7c523298", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bf85d003dd34911d89c8ddb8bda1a958af3471a274a4c2150a9c01c78ac3f8ed"}, + "ex_doc": {:hex, :ex_doc, "0.39.3", "519c6bc7e84a2918b737aec7ef48b96aa4698342927d080437f61395d361dcee", [:mix], [{:earmark_parser, "~> 1.4.44", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "0590955cf7ad3b625780ee1c1ea627c28a78948c6c0a9b0322bd976a079996e1"}, "jason": {:hex, :jason, "1.3.0", "fa6b82a934feb176263ad2df0dbd91bf633d4a46ebfdffea0c8ae82953714946", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "53fc1f51255390e0ec7e50f9cb41e751c260d065dcba2bf0d08dc51a4002c2ac"}, - "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, - "makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"}, - "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, + "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, + "makeup_elixir": {:hex, :makeup_elixir, "1.0.1", "e928a4f984e795e41e3abd27bfc09f51db16ab8ba1aebdba2b3a575437efafc2", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "7284900d412a3e5cfd97fdaed4f5ed389b8f2b4cb49efc0eb3bd10e2febf9507"}, + "makeup_erlang": {:hex, :makeup_erlang, "1.0.2", "03e1804074b3aa64d5fad7aa64601ed0fb395337b982d9bcf04029d68d51b6a7", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "af33ff7ef368d5893e4a267933e7744e46ce3cf1f61e2dccf53a111ed3aa3727"}, "nimble_options": {:hex, :nimble_options, "1.0.1", "b448018287b22584e91b5fd9c6c0ad717cb4bcdaa457957c8d57770f56625c43", [:mix], [], "hexpm", "078b2927cd9f84555be6386d56e849b0c555025ecccf7afee00ab6a9e6f63837"}, - "nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"}, + "nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"}, "schematic": {:hex, :schematic, "0.2.1", "0b091df94146fd15a0a343d1bd179a6c5a58562527746dadd09477311698dbb1", [:mix], [{:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "0b255d65921e38006138201cd4263fd8bb807d9dfc511074615cd264a571b3b1"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "typed_struct": {:hex, :typed_struct, "0.3.0", "939789e3c1dca39d7170c87f729127469d1315dcf99fee8e152bb774b17e7ff7", [:mix], [], "hexpm", "c50bd5c3a61fe4e198a8504f939be3d3c85903b382bde4865579bc23111d1b6d"}, diff --git a/test/gen_lsp_test.exs b/test/gen_lsp_test.exs index 13a2e82..8cdb451 100644 --- a/test/gen_lsp_test.exs +++ b/test/gen_lsp_test.exs @@ -19,7 +19,7 @@ defmodule GenLSPTest do assert %{foo: :bar, test_pid: self()} == :sys.get_state(server.assigns) end - test "can receive and reply to a request", %{client: client} do + test "can receive and reply to a request, with handle_continue", %{client: client} do id = System.unique_integer([:positive]) assert :ok == @@ -30,6 +30,7 @@ defmodule GenLSPTest do "id" => id }) + # Response arrives first assert_result ^id, %{ "capabilities" => %{ @@ -39,6 +40,9 @@ defmodule GenLSPTest do "serverInfo" => %{"name" => "Test LSP"} }, 500 + + # Then handle_continue fires (after response is flushed to wire) + assert_notification "window/logMessage", %{"message" => "post_init continue executed"}, 500 end test "accepts a string id", %{client: client} do @@ -273,6 +277,13 @@ defmodule GenLSPTest do assert_receive {:info, :ack} end + test "handle_continue chains in sequence", %{server: server} do + send(server.lsp, :trigger_chain) + + assert_receive {:continue, :chain_step1}, 500 + assert_receive {:continue, :chain_step2}, 500 + end + test "can respond with an error", %{client: client} do id = System.unique_integer([:positive]) diff --git a/test/support/example_server.ex b/test/support/example_server.ex index c6936bc..ac42ff7 100644 --- a/test/support/example_server.ex +++ b/test/support/example_server.ex @@ -24,7 +24,7 @@ defmodule GenLSPTest.ExampleServer do call_hierarchy_provider: %CallHierarchyOptions{work_done_progress: true} }, server_info: %{name: "Test LSP"} - }, lsp} + }, lsp, {:continue, :post_init}} end def handle_request(%Requests.TextDocumentFormatting{}, lsp) do @@ -54,6 +54,23 @@ defmodule GenLSPTest.ExampleServer do }, lsp} end + @impl true + def handle_continue(:post_init, lsp) do + GenLSP.log(lsp, "post_init continue executed") + {:noreply, lsp} + end + + def handle_continue(:chain_step1, lsp) do + send(assigns(lsp).test_pid, {:continue, :chain_step1}) + {:noreply, lsp, {:continue, :chain_step2}} + end + + def handle_continue(:chain_step2, lsp) do + send(assigns(lsp).test_pid, {:continue, :chain_step2}) + {:noreply, lsp} + end + + @impl true def handle_notification(%Notifications.Initialized{}, lsp) do GenLSP.request(lsp, %GenLSP.Requests.ClientRegisterCapability{ id: System.unique_integer([:positive]), @@ -135,6 +152,10 @@ defmodule GenLSPTest.ExampleServer do {:noreply, lsp} end + def handle_info(:trigger_chain, lsp) do + {:noreply, lsp, {:continue, :chain_step1}} + end + def handle_info(_message, lsp) do send(assigns(lsp).test_pid, {:info, :ack}) {:noreply, lsp}