mirror of
https://github.com/neovim/neovim.git
synced 2024-12-20 11:15:14 -07:00
sockets: don't deadlock when connecting to own pipe address
This commit is contained in:
parent
6a75938758
commit
5a151555c8
@ -12,6 +12,7 @@
|
||||
#include "nvim/api/vim.h"
|
||||
#include "nvim/api/ui.h"
|
||||
#include "nvim/msgpack_rpc/channel.h"
|
||||
#include "nvim/msgpack_rpc/server.h"
|
||||
#include "nvim/event/loop.h"
|
||||
#include "nvim/event/libuv_process.h"
|
||||
#include "nvim/event/rstream.h"
|
||||
@ -28,6 +29,7 @@
|
||||
#include "nvim/map.h"
|
||||
#include "nvim/log.h"
|
||||
#include "nvim/misc1.h"
|
||||
#include "nvim/path.h"
|
||||
#include "nvim/lib/kvec.h"
|
||||
#include "nvim/os/input.h"
|
||||
|
||||
@ -41,7 +43,8 @@
|
||||
typedef enum {
|
||||
kChannelTypeSocket,
|
||||
kChannelTypeProc,
|
||||
kChannelTypeStdio
|
||||
kChannelTypeStdio,
|
||||
kChannelTypeInternal
|
||||
} ChannelType;
|
||||
|
||||
typedef struct {
|
||||
@ -125,7 +128,7 @@ uint64_t channel_from_process(Process *proc, uint64_t id)
|
||||
|
||||
wstream_init(proc->in, 0);
|
||||
rstream_init(proc->out, 0);
|
||||
rstream_start(proc->out, parse_msgpack, channel);
|
||||
rstream_start(proc->out, receive_msgpack, channel);
|
||||
|
||||
return channel->id;
|
||||
}
|
||||
@ -142,12 +145,22 @@ void channel_from_connection(SocketWatcher *watcher)
|
||||
channel->data.stream.internal_data = channel;
|
||||
wstream_init(&channel->data.stream, 0);
|
||||
rstream_init(&channel->data.stream, CHANNEL_BUFFER_SIZE);
|
||||
rstream_start(&channel->data.stream, parse_msgpack, channel);
|
||||
rstream_start(&channel->data.stream, receive_msgpack, channel);
|
||||
}
|
||||
|
||||
uint64_t channel_connect(bool tcp, const char *address,
|
||||
int timeout, const char **error)
|
||||
{
|
||||
if (!tcp) {
|
||||
char *path = fix_fname(address);
|
||||
if (server_owns_pipe_address(path)) {
|
||||
// avoid deadlock
|
||||
xfree(path);
|
||||
return channel_create_internal();
|
||||
}
|
||||
xfree(path);
|
||||
}
|
||||
|
||||
Channel *channel = register_channel(kChannelTypeSocket, 0, NULL);
|
||||
if (!socket_connect(&main_loop, &channel->data.stream,
|
||||
tcp, address, timeout, error)) {
|
||||
@ -160,7 +173,7 @@ uint64_t channel_connect(bool tcp, const char *address,
|
||||
channel->data.stream.internal_data = channel;
|
||||
wstream_init(&channel->data.stream, 0);
|
||||
rstream_init(&channel->data.stream, CHANNEL_BUFFER_SIZE);
|
||||
rstream_start(&channel->data.stream, parse_msgpack, channel);
|
||||
rstream_start(&channel->data.stream, receive_msgpack, channel);
|
||||
return channel->id;
|
||||
}
|
||||
|
||||
@ -324,11 +337,20 @@ void channel_from_stdio(void)
|
||||
incref(channel); // stdio channels are only closed on exit
|
||||
// read stream
|
||||
rstream_init_fd(&main_loop, &channel->data.std.in, 0, CHANNEL_BUFFER_SIZE);
|
||||
rstream_start(&channel->data.std.in, parse_msgpack, channel);
|
||||
rstream_start(&channel->data.std.in, receive_msgpack, channel);
|
||||
// write stream
|
||||
wstream_init_fd(&main_loop, &channel->data.std.out, 1, 0);
|
||||
}
|
||||
|
||||
/// Creates a loopback channel. This is used to avoid deadlock
|
||||
/// when an instance connects to its own named pipe.
|
||||
uint64_t channel_create_internal(void)
|
||||
{
|
||||
Channel *channel = register_channel(kChannelTypeInternal, 0, NULL);
|
||||
incref(channel); // internal channel lives until process exit
|
||||
return channel->id;
|
||||
}
|
||||
|
||||
void channel_process_exit(uint64_t id, int status)
|
||||
{
|
||||
Channel *channel = pmap_get(uint64_t)(channels, id);
|
||||
@ -337,8 +359,8 @@ void channel_process_exit(uint64_t id, int status)
|
||||
decref(channel);
|
||||
}
|
||||
|
||||
static void parse_msgpack(Stream *stream, RBuffer *rbuf, size_t c, void *data,
|
||||
bool eof)
|
||||
static void receive_msgpack(Stream *stream, RBuffer *rbuf, size_t c,
|
||||
void *data, bool eof)
|
||||
{
|
||||
Channel *channel = data;
|
||||
incref(channel);
|
||||
@ -360,6 +382,14 @@ static void parse_msgpack(Stream *stream, RBuffer *rbuf, size_t c, void *data,
|
||||
rbuffer_read(rbuf, msgpack_unpacker_buffer(channel->unpacker), count);
|
||||
msgpack_unpacker_buffer_consumed(channel->unpacker, count);
|
||||
|
||||
parse_msgpack(channel);
|
||||
|
||||
end:
|
||||
decref(channel);
|
||||
}
|
||||
|
||||
static void parse_msgpack(Channel *channel)
|
||||
{
|
||||
msgpack_unpacked unpacked;
|
||||
msgpack_unpacked_init(&unpacked);
|
||||
msgpack_unpack_return result;
|
||||
@ -383,7 +413,7 @@ static void parse_msgpack(Stream *stream, RBuffer *rbuf, size_t c, void *data,
|
||||
}
|
||||
msgpack_unpacked_destroy(&unpacked);
|
||||
// Bail out from this event loop iteration
|
||||
goto end;
|
||||
return;
|
||||
}
|
||||
|
||||
handle_request(channel, &unpacked.data);
|
||||
@ -407,11 +437,9 @@ static void parse_msgpack(Stream *stream, RBuffer *rbuf, size_t c, void *data,
|
||||
"This error can also happen when deserializing "
|
||||
"an object with high level of nesting");
|
||||
}
|
||||
|
||||
end:
|
||||
decref(channel);
|
||||
}
|
||||
|
||||
|
||||
static void handle_request(Channel *channel, msgpack_object *request)
|
||||
FUNC_ATTR_NONNULL_ALL
|
||||
{
|
||||
@ -521,8 +549,11 @@ static bool channel_write(Channel *channel, WBuffer *buffer)
|
||||
case kChannelTypeStdio:
|
||||
success = wstream_write(&channel->data.std.out, buffer);
|
||||
break;
|
||||
default:
|
||||
abort();
|
||||
case kChannelTypeInternal:
|
||||
incref(channel);
|
||||
CREATE_EVENT(channel->events, internal_read_event, 2, channel, buffer);
|
||||
success = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!success) {
|
||||
@ -539,6 +570,22 @@ static bool channel_write(Channel *channel, WBuffer *buffer)
|
||||
return success;
|
||||
}
|
||||
|
||||
static void internal_read_event(void **argv)
|
||||
{
|
||||
Channel *channel = argv[0];
|
||||
WBuffer *buffer = argv[1];
|
||||
|
||||
msgpack_unpacker_reserve_buffer(channel->unpacker, buffer->size);
|
||||
memcpy(msgpack_unpacker_buffer(channel->unpacker),
|
||||
buffer->data, buffer->size);
|
||||
msgpack_unpacker_buffer_consumed(channel->unpacker, buffer->size);
|
||||
|
||||
parse_msgpack(channel);
|
||||
|
||||
decref(channel);
|
||||
wstream_release_wbuffer(buffer);
|
||||
}
|
||||
|
||||
static void send_error(Channel *channel, uint64_t id, char *err)
|
||||
{
|
||||
Error e = ERROR_INIT;
|
||||
@ -655,8 +702,9 @@ static void close_channel(Channel *channel)
|
||||
stream_close(&channel->data.std.out, NULL, NULL);
|
||||
multiqueue_put(main_loop.fast_events, exit_event, 1, channel);
|
||||
return;
|
||||
default:
|
||||
abort();
|
||||
case kChannelTypeInternal:
|
||||
// nothing to free.
|
||||
break;
|
||||
}
|
||||
|
||||
decref(channel);
|
||||
|
@ -97,6 +97,18 @@ char *server_address_new(void)
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Check if this instance owns a pipe address.
|
||||
/// The argument must already be resolved to an absolute path!
|
||||
bool server_owns_pipe_address(const char *path)
|
||||
{
|
||||
for (int i = 0; i < watchers.ga_len; i++) {
|
||||
if (!strcmp(path, ((SocketWatcher **)watchers.ga_data)[i]->addr)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Starts listening for API calls.
|
||||
///
|
||||
/// The socket type is determined by parsing `endpoint`: If it's a valid IPv4
|
||||
|
@ -1715,7 +1715,7 @@ int vim_FullName(const char *fname, char *buf, size_t len, bool force)
|
||||
///
|
||||
/// @param fname is the filename to expand
|
||||
/// @return [allocated] Full path (NULL for failure).
|
||||
char *fix_fname(char *fname)
|
||||
char *fix_fname(const char *fname)
|
||||
{
|
||||
#ifdef UNIX
|
||||
return FullName_save(fname, true);
|
||||
|
@ -282,4 +282,20 @@ describe('server -> client', function()
|
||||
end)
|
||||
end)
|
||||
|
||||
describe('when connecting to its own pipe adress', function()
|
||||
it('it does not deadlock', function()
|
||||
local address = funcs.serverlist()[1]
|
||||
local first = string.sub(address,1,1)
|
||||
ok(first == '/' or first == '\\')
|
||||
local serverpid = funcs.getpid()
|
||||
|
||||
local id = funcs.sockconnect('pipe', address, {rpc=true})
|
||||
|
||||
funcs.rpcrequest(id, 'nvim_set_current_line', 'hello')
|
||||
eq('hello', meths.get_current_line())
|
||||
eq(serverpid, funcs.rpcrequest(id, "nvim_eval", "getpid()"))
|
||||
|
||||
eq(id, funcs.rpcrequest(id, 'nvim_get_api_info')[1])
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
Loading…
Reference in New Issue
Block a user