This commit is contained in:
Mathias Fußenegger 2024-12-18 14:40:43 +00:00 committed by GitHub
commit 8fc2951efc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -16,34 +16,21 @@ local function format_message_with_content_length(message)
}) })
end end
---@class (private) vim.lsp.rpc.Headers: {string: any} --- Extract content-length from the msg header
---@field content_length integer
--- Parses an LSP Message's header
--- ---
---@param header string The header to parse. ---@param header string The header to parse
---@return vim.lsp.rpc.Headers#parsed headers ---@return integer?
local function parse_headers(header) local function get_content_length(header)
assert(type(header) == 'string', 'header must be a string') for line in header:gmatch('(.-)\r\n') do
--- @type vim.lsp.rpc.Headers
local headers = {}
for line in vim.gsplit(header, '\r\n', { plain = true }) do
if line == '' then if line == '' then
break break
end end
--- @type string?, string? local key, value = line:match('^%s*(%S+)%s*:%s*(%d+)%s*$')
local key, value = line:match('^%s*(%S+)%s*:%s*(.+)%s*$') if key and key:lower() == 'content-length' then
if key then return tonumber(value)
key = key:lower():gsub('%-', '_') --- @type string
headers[key] = value
else
log.error('invalid header line %q', line)
error(string.format('invalid header line %q', line))
end end
end end
headers.content_length = tonumber(headers.content_length) error('Content-Length not found in header: ' .. header)
or error(string.format('Content-Length not found in headers. %q', header))
return headers
end end
-- This is the start of any possible header patterns. The gsub converts it to a -- This is the start of any possible header patterns. The gsub converts it to a
@ -52,70 +39,102 @@ local header_start_pattern = ('content'):gsub('%w', function(c)
return '[' .. c .. c:upper() .. ']' return '[' .. c .. c:upper() .. ']'
end) end)
local has_strbuffer, strbuffer = pcall(require, "string.buffer")
--- The actual workhorse. --- The actual workhorse.
local function request_parser_loop() ---@type function
local buffer = '' -- only for header part local request_parser_loop
while true do
-- A message can only be complete if it has a double CRLF and also the full if has_strbuffer then
-- payload, so first let's check for the CRLFs request_parser_loop = function()
local start, finish = buffer:find('\r\n\r\n', 1, true) local buf = strbuffer.new()
-- Start parsing the headers while true do
if start then local msg = buf:tostring()
-- This is a workaround for servers sending initial garbage before local header_end = msg:find('\r\n\r\n', 1, true)
-- sending headers, such as if a bash script sends stdout. It assumes if header_end then
-- that we know all of the headers ahead of time. At this moment, the local header = buf:get(header_end + 1)
-- only valid headers start with "Content-*", so that's the thing we will buf:skip(2) -- skip past header boundary
-- be searching for. local content_length = get_content_length(header)
-- TODO(ashkan) I'd like to remove this, but it seems permanent :( while #buf < content_length do
local buffer_start = buffer:find(header_start_pattern) local chunk = coroutine.yield()
if not buffer_start then buf:put(chunk)
error( end
string.format( local body = buf:get(content_length)
"Headers were expected, a different response was received. The server response was '%s'.", local chunk = coroutine.yield(body)
buffer buf:put(chunk)
) else
)
end
local headers = parse_headers(buffer:sub(buffer_start, start - 1))
local content_length = headers.content_length
-- Use table instead of just string to buffer the message. It prevents
-- a ton of strings allocating.
-- ref. http://www.lua.org/pil/11.6.html
---@type string[]
local body_chunks = { buffer:sub(finish + 1) }
local body_length = #body_chunks[1]
-- Keep waiting for data until we have enough.
while body_length < content_length do
---@type string
local chunk = coroutine.yield() local chunk = coroutine.yield()
or error('Expected more data for the body. The server may have died.') -- TODO hmm. buf:put(chunk)
table.insert(body_chunks, chunk)
body_length = body_length + #chunk
end end
local last_chunk = body_chunks[#body_chunks] end
end
else
request_parser_loop = function()
local buffer = '' -- only for header part
while true do
-- A message can only be complete if it has a double CRLF and also the full
-- payload, so first let's check for the CRLFs
local header_end, body_start = buffer:find('\r\n\r\n', 1, true)
-- Start parsing the headers
if header_end then
-- This is a workaround for servers sending initial garbage before
-- sending headers, such as if a bash script sends stdout. It assumes
-- that we know all of the headers ahead of time. At this moment, the
-- only valid headers start with "Content-*", so that's the thing we will
-- be searching for.
-- TODO(ashkan) I'd like to remove this, but it seems permanent :(
local buffer_start = buffer:find(header_start_pattern)
if not buffer_start then
error(
string.format(
"Headers were expected, a different response was received. The server response was '%s'.",
buffer
)
)
end
local header = buffer:sub(buffer_start, header_end + 1)
local content_length = get_content_length(header)
-- Use table instead of just string to buffer the message. It prevents
-- a ton of strings allocating.
-- ref. http://www.lua.org/pil/11.6.html
---@type string[]
local body_chunks = { buffer:sub(body_start + 1) }
local body_length = #body_chunks[1]
-- Keep waiting for data until we have enough.
while body_length < content_length do
---@type string
local chunk = coroutine.yield()
or error('Expected more data for the body. The server may have died.') -- TODO hmm.
table.insert(body_chunks, chunk)
body_length = body_length + #chunk
end
local last_chunk = body_chunks[#body_chunks]
body_chunks[#body_chunks] = last_chunk:sub(1, content_length - body_length - 1) body_chunks[#body_chunks] = last_chunk:sub(1, content_length - body_length - 1)
local rest = '' local rest = ''
if body_length > content_length then if body_length > content_length then
rest = last_chunk:sub(content_length - body_length) rest = last_chunk:sub(content_length - body_length)
end
local body = table.concat(body_chunks)
-- Yield our data.
--- @type string
local data = coroutine.yield(body)
or error('Expected more data for the body. The server may have died.')
buffer = rest .. data
else
-- Get more data since we don't have enough.
--- @type string
local data = coroutine.yield()
or error('Expected more data for the header. The server may have died.')
buffer = buffer .. data
end end
local body = table.concat(body_chunks)
-- Yield our data.
--- @type string
local data = coroutine.yield(headers, body)
or error('Expected more data for the body. The server may have died.')
buffer = rest .. data
else
-- Get more data since we don't have enough.
--- @type string
local data = coroutine.yield()
or error('Expected more data for the header. The server may have died.')
buffer = buffer .. data
end end
end end
end end
local M = {} local M = {}
--- Mapping of error codes used by the client --- Mapping of error codes used by the client
@ -237,7 +256,7 @@ local default_dispatchers = {
--- @param on_exit? fun() --- @param on_exit? fun()
--- @param on_error fun(err: any) --- @param on_error fun(err: any)
function M.create_read_loop(handle_body, on_exit, on_error) function M.create_read_loop(handle_body, on_exit, on_error)
local parse_chunk = coroutine.wrap(request_parser_loop) --[[@as fun(chunk: string?): vim.lsp.rpc.Headers?, string?]] local parse_chunk = coroutine.wrap(request_parser_loop) --[[@as fun(chunk: string?): string]]
parse_chunk() parse_chunk()
return function(err, chunk) return function(err, chunk)
if err then if err then
@ -253,9 +272,9 @@ function M.create_read_loop(handle_body, on_exit, on_error)
end end
while true do while true do
local headers, body = parse_chunk(chunk) local body = parse_chunk(chunk)
if headers then if body then
handle_body(assert(body)) handle_body(body)
chunk = '' chunk = ''
else else
break break