From d1c470957b49380ec5ceba603dbd85a14f60f09b Mon Sep 17 00:00:00 2001
From: jdrouhard <john@jmdtech.org>
Date: Fri, 29 Oct 2021 07:45:01 -0500
Subject: [PATCH] feat(lsp): track pending+cancel requests on client object
 #15949

---
 runtime/doc/lsp.txt                          |  21 +++
 runtime/lua/vim/lsp.lua                      |  23 ++-
 runtime/lua/vim/lsp/rpc.lua                  |  29 +++-
 test/functional/fixtures/fake-lsp-server.lua |  49 +++++++
 test/functional/plugin/lsp_spec.lua          | 142 ++++++++++++++++++-
 5 files changed, 254 insertions(+), 10 deletions(-)

diff --git a/runtime/doc/lsp.txt b/runtime/doc/lsp.txt
index 5549d3c180..30e8ff658b 100644
--- a/runtime/doc/lsp.txt
+++ b/runtime/doc/lsp.txt
@@ -451,6 +451,22 @@ LspSignatureActiveParameter
    Used to highlight the active parameter in the signature help. See
    |vim.lsp.handlers.signature_help()|.
 
+==============================================================================
+EVENTS                                                            *lsp-events*
+
+LspProgressUpdate                                          *LspProgressUpdate*
+   Upon receipt of a progress notification from the server. See
+   |vim.lsp.util.get_progress_messages()|.
+
+LspRequest                                                        *LspRequest*
+   After a change to the active set of pending LSP requests. See {requests}
+   in |vim.lsp.client|.
+
+Example: >
+   autocmd User LspProgressUpdate redrawstatus
+   autocmd User LspRequest redrawstatus
+<
+
 ==============================================================================
 Lua module: vim.lsp                                                 *lsp-core*
 
@@ -608,6 +624,11 @@ client()                                                      *vim.lsp.client*
                     server.
                   • {handlers} (table): The handlers used by the client as
                     described in |lsp-handler|.
+                  • {requests} (table): The current pending requests in flight
+                    to the server. Entries are key-value pairs with the key
+                    being the request ID while the value is a table with `type`,
+                    `bufnr`, and `method` key-value pairs. `type` is either "pending"
+                    for an active request, or "cancel" for a cancel request.
                   • {config} (table): copy of the table that was passed by the
                     user to |vim.lsp.start_client()|.
                   • {server_capabilities} (table): Response from the server
diff --git a/runtime/lua/vim/lsp.lua b/runtime/lua/vim/lsp.lua
index 56ac1cbc66..3a067373d0 100644
--- a/runtime/lua/vim/lsp.lua
+++ b/runtime/lua/vim/lsp.lua
@@ -772,8 +772,10 @@ function lsp.start_client(config)
     attached_buffers = {};
 
     handlers = handlers;
+    requests = {};
+
     -- for $/progress report
-    messages = { name = name, messages = {}, progress = {}, status = {} }
+    messages = { name = name, messages = {}, progress = {}, status = {} };
   }
 
   -- Store the uninitialized_clients for cleanup in case we exit before initialize finishes.
@@ -906,11 +908,21 @@ function lsp.start_client(config)
     end
     -- Ensure pending didChange notifications are sent so that the server doesn't operate on a stale state
     changetracking.flush(client)
-
+    bufnr = resolve_bufnr(bufnr)
     local _ = log.debug() and log.debug(log_prefix, "client.request", client_id, method, params, handler, bufnr)
-    return rpc.request(method, params, function(err, result)
+    local success, request_id = rpc.request(method, params, function(err, result)
       handler(err, result, {method=method, client_id=client_id, bufnr=bufnr, params=params})
+    end, function(request_id)
+      client.requests[request_id] = nil
+      nvim_command("doautocmd <nomodeline> User LspRequest")
     end)
+
+    if success then
+      client.requests[request_id] = { type='pending', bufnr=bufnr, method=method }
+      nvim_command("doautocmd <nomodeline> User LspRequest")
+    end
+
+    return success, request_id
   end
 
   ---@private
@@ -970,6 +982,11 @@ function lsp.start_client(config)
   ---@see |vim.lsp.client.notify()|
   function client.cancel_request(id)
     validate{id = {id, 'n'}}
+    local request = client.requests[id]
+    if request and request.type == 'pending' then
+      request.type = 'cancel'
+      nvim_command("doautocmd <nomodeline> User LspRequest")
+    end
     return rpc.notify("$/cancelRequest", { id = id })
   end
 
diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua
index d9a684a738..bce1e9f35d 100644
--- a/runtime/lua/vim/lsp/rpc.lua
+++ b/runtime/lua/vim/lsp/rpc.lua
@@ -297,6 +297,7 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params)
 
   local message_index = 0
   local message_callbacks = {}
+  local notify_reply_callbacks = {}
 
   local handle, pid
   do
@@ -309,8 +310,9 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params)
       stdout:close()
       stderr:close()
       handle:close()
-      -- Make sure that message_callbacks can be gc'd.
+      -- Make sure that message_callbacks/notify_reply_callbacks can be gc'd.
       message_callbacks = nil
+      notify_reply_callbacks = nil
       dispatchers.on_exit(code, signal)
     end
     local spawn_params = {
@@ -375,10 +377,12 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params)
   ---@param method (string) The invoked LSP method
   ---@param params (table) Parameters for the invoked LSP method
   ---@param callback (function) Callback to invoke
+  ---@param notify_reply_callback (function) Callback to invoke as soon as a request is no longer pending
   ---@returns (bool, number) `(true, message_id)` if request could be sent, `false` if not
-  local function request(method, params, callback)
+  local function request(method, params, callback, notify_reply_callback)
     validate {
       callback = { callback, 'f' };
+      notify_reply_callback = { notify_reply_callback, 'f', true };
     }
     message_index = message_index + 1
     local message_id = message_index
@@ -388,8 +392,15 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params)
       method = method;
       params = params;
     }
-    if result and message_callbacks then
-      message_callbacks[message_id] = schedule_wrap(callback)
+    if result then
+      if message_callbacks then
+        message_callbacks[message_id] = schedule_wrap(callback)
+      else
+        return false
+      end
+      if notify_reply_callback and notify_reply_callbacks then
+        notify_reply_callbacks[message_id] = schedule_wrap(notify_reply_callback)
+      end
       return result, message_id
     else
       return false
@@ -466,6 +477,16 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params)
       -- We sent a number, so we expect a number.
       local result_id = tonumber(decoded.id)
 
+      -- Notify the user that a response was received for the request
+      local notify_reply_callback = notify_reply_callbacks and notify_reply_callbacks[result_id]
+      if notify_reply_callback then
+        validate {
+          notify_reply_callback = { notify_reply_callback, 'f' };
+        }
+        notify_reply_callback(result_id)
+        notify_reply_callbacks[result_id] = nil
+      end
+
       -- Do not surface RequestCancelled to users, it is RPC-internal.
       if decoded.error then
         local mute_error = false
diff --git a/test/functional/fixtures/fake-lsp-server.lua b/test/functional/fixtures/fake-lsp-server.lua
index 9abf478070..523e1e11fd 100644
--- a/test/functional/fixtures/fake-lsp-server.lua
+++ b/test/functional/fixtures/fake-lsp-server.lua
@@ -275,6 +275,55 @@ function tests.check_forward_content_modified()
   }
 end
 
+function tests.check_pending_request_tracked()
+  skeleton {
+    on_init = function(_)
+      return { capabilities = {} }
+    end;
+    body = function()
+        local msg = read_message()
+        assert_eq('slow_request', msg.method)
+        expect_notification('release')
+        respond(msg.id, nil, {})
+        expect_notification('finish')
+        notify('finish')
+    end;
+  }
+end
+
+function tests.check_cancel_request_tracked()
+  skeleton {
+    on_init = function(_)
+      return { capabilities = {} }
+    end;
+    body = function()
+        local msg = read_message()
+        assert_eq('slow_request', msg.method)
+        expect_notification('$/cancelRequest', {id=msg.id})
+        expect_notification('release')
+        respond(msg.id, {code = -32800}, nil)
+        notify('finish')
+    end;
+  }
+end
+
+function tests.check_tracked_requests_cleared()
+  skeleton {
+    on_init = function(_)
+      return { capabilities = {} }
+    end;
+    body = function()
+        local msg = read_message()
+        assert_eq('slow_request', msg.method)
+        expect_notification('$/cancelRequest', {id=msg.id})
+        expect_notification('release')
+        respond(msg.id, nil, {})
+        expect_notification('finish')
+        notify('finish')
+    end;
+  }
+end
+
 function tests.basic_finish()
   skeleton {
     on_init = function(params)
diff --git a/test/functional/plugin/lsp_spec.lua b/test/functional/plugin/lsp_spec.lua
index ce50abb50d..e89bfcefbf 100644
--- a/test/functional/plugin/lsp_spec.lua
+++ b/test/functional/plugin/lsp_spec.lua
@@ -3,9 +3,11 @@ local helpers = require('test.functional.helpers')(after_each)
 local assert_log = helpers.assert_log
 local clear = helpers.clear
 local buf_lines = helpers.buf_lines
+local command = helpers.command
 local dedent = helpers.dedent
 local exec_lua = helpers.exec_lua
 local eq = helpers.eq
+local eval = helpers.eval
 local matches = helpers.matches
 local pcall_err = helpers.pcall_err
 local pesc = helpers.pesc
@@ -272,7 +274,7 @@ describe('LSP', function()
         return
       end
       local expected_handlers = {
-        {NIL, {}, {method="shutdown", client_id=1}};
+        {NIL, {}, {method="shutdown", bufnr=1, client_id=1}};
         {NIL, {}, {method="test", client_id=1}};
       }
       test_rpc_server {
@@ -486,7 +488,7 @@ describe('LSP', function()
     it('should forward ContentModified to callback', function()
       local expected_handlers = {
         {NIL, {}, {method="finish", client_id=1}};
-        {{code = -32801}, NIL, {method = "error_code_test", client_id=1}};
+        {{code = -32801}, NIL, {method = "error_code_test", bufnr=1, client_id=1}};
       }
       local client
       test_rpc_server {
@@ -509,6 +511,140 @@ describe('LSP', function()
       }
     end)
 
+    it('should track pending requests to the language server', function()
+      local expected_handlers = {
+        {NIL, {}, {method="finish", client_id=1}};
+        {NIL, {}, {method="slow_request", bufnr=1, client_id=1}};
+      }
+      local client
+      test_rpc_server {
+        test_name = "check_pending_request_tracked";
+        on_init = function(_client)
+          client = _client
+          client.request("slow_request")
+          local request = exec_lua([=[ return TEST_RPC_CLIENT.requests[2] ]=])
+          eq("slow_request", request.method)
+          eq("pending", request.type)
+          client.notify("release")
+        end;
+        on_exit = function(code, signal)
+          eq(0, code, "exit code", fake_lsp_logfile)
+          eq(0, signal, "exit signal", fake_lsp_logfile)
+          eq(0, #expected_handlers, "did not call expected handler")
+        end;
+        on_handler = function(err, _, ctx)
+          eq(table.remove(expected_handlers), {err, {}, ctx}, "expected handler")
+          if ctx.method == 'slow_request' then
+            local request = exec_lua([=[ return TEST_RPC_CLIENT.requests[2] ]=])
+            eq(NIL, request)
+            client.notify("finish")
+          end
+          if ctx.method == 'finish' then client.stop() end
+        end;
+      }
+    end)
+
+    it('should track cancel requests to the language server', function()
+      local expected_handlers = {
+        {NIL, {}, {method="finish", client_id=1}};
+      }
+      local client
+      test_rpc_server {
+        test_name = "check_cancel_request_tracked";
+        on_init = function(_client)
+          client = _client
+          client.request("slow_request")
+          client.cancel_request(2)
+          local request = exec_lua([=[ return TEST_RPC_CLIENT.requests[2] ]=])
+          eq("slow_request", request.method)
+          eq("cancel", request.type)
+          client.notify("release")
+        end;
+        on_exit = function(code, signal)
+          eq(0, code, "exit code", fake_lsp_logfile)
+          eq(0, signal, "exit signal", fake_lsp_logfile)
+          eq(0, #expected_handlers, "did not call expected handler")
+        end;
+        on_handler = function(err, _, ctx)
+          eq(table.remove(expected_handlers), {err, {}, ctx}, "expected handler")
+          local request = exec_lua([=[ return TEST_RPC_CLIENT.requests[2] ]=])
+          eq(NIL, request)
+          if ctx.method == 'finish' then client.stop() end
+        end;
+      }
+    end)
+
+    it('should clear pending and cancel requests on reply', function()
+      local expected_handlers = {
+        {NIL, {}, {method="finish", client_id=1}};
+        {NIL, {}, {method="slow_request", bufnr=1, client_id=1}};
+      }
+      local client
+      test_rpc_server {
+        test_name = "check_tracked_requests_cleared";
+        on_init = function(_client)
+          client = _client
+          client.request("slow_request")
+          local request = exec_lua([=[ return TEST_RPC_CLIENT.requests[2] ]=])
+          eq("slow_request", request.method)
+          eq("pending", request.type)
+          client.cancel_request(2)
+          request = exec_lua([=[ return TEST_RPC_CLIENT.requests[2] ]=])
+          eq("slow_request", request.method)
+          eq("cancel", request.type)
+          client.notify("release")
+        end;
+        on_exit = function(code, signal)
+          eq(0, code, "exit code", fake_lsp_logfile)
+          eq(0, signal, "exit signal", fake_lsp_logfile)
+          eq(0, #expected_handlers, "did not call expected handler")
+        end;
+        on_handler = function(err, _, ctx)
+          eq(table.remove(expected_handlers), {err, {}, ctx}, "expected handler")
+          if ctx.method == 'slow_request' then
+            local request = exec_lua([=[ return TEST_RPC_CLIENT.requests[2] ]=])
+            eq(NIL, request)
+            client.notify("finish")
+          end
+          if ctx.method == 'finish' then client.stop() end
+        end;
+      }
+    end)
+
+    it('should trigger LspRequest autocmd when requests table changes', function()
+      local expected_handlers = {
+        {NIL, {}, {method="finish", client_id=1}};
+        {NIL, {}, {method="slow_request", bufnr=1, client_id=1}};
+      }
+      local client
+      test_rpc_server {
+        test_name = "check_tracked_requests_cleared";
+        on_init = function(_client)
+          command('let g:requests = 0')
+          command('autocmd User LspRequest let g:requests+=1')
+          client = _client
+          client.request("slow_request")
+          eq(1, eval('g:requests'))
+          client.cancel_request(2)
+          eq(2, eval('g:requests'))
+          client.notify("release")
+        end;
+        on_exit = function(code, signal)
+          eq(0, code, "exit code", fake_lsp_logfile)
+          eq(0, signal, "exit signal", fake_lsp_logfile)
+          eq(0, #expected_handlers, "did not call expected handler")
+          eq(3, eval('g:requests'))
+        end;
+        on_handler = function(err, _, ctx)
+          eq(table.remove(expected_handlers), {err, {}, ctx}, "expected handler")
+          if ctx.method == 'slow_request' then
+            client.notify("finish")
+          end
+          if ctx.method == 'finish' then client.stop() end
+        end;
+      }
+    end)
+
     it('should not send didOpen if the buffer closes before init', function()
       local expected_handlers = {
         {NIL, {}, {method="shutdown", client_id=1}};
@@ -790,7 +926,7 @@ describe('LSP', function()
     -- TODO(askhan) we don't support full for now, so we can disable these tests.
     pending('should check the body and didChange incremental normal mode editing', function()
       local expected_handlers = {
-        {NIL, {}, {method="shutdown", client_id=1}};
+        {NIL, {}, {method="shutdown", bufnr=1, client_id=1}};
         {NIL, {}, {method="finish", client_id=1}};
         {NIL, {}, {method="start", client_id=1}};
       }