syncthing/cmd/discosrv/main.go

370 lines
7.5 KiB
Go
Raw Normal View History

2014-07-12 15:45:33 -07:00
// Copyright (C) 2014 Jakob Borg and Contributors (see the CONTRIBUTORS file).
// All rights reserved. Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
2014-06-01 13:50:14 -07:00
2013-12-22 19:35:05 -07:00
package main
import (
2014-09-08 02:48:26 -07:00
"bytes"
2014-02-20 09:40:15 -07:00
"encoding/binary"
"encoding/hex"
"flag"
2014-04-19 14:14:56 -07:00
"fmt"
"io"
2013-12-22 19:35:05 -07:00
"log"
"net"
2014-02-20 09:40:15 -07:00
"os"
2014-09-08 02:48:26 -07:00
"path/filepath"
2013-12-22 19:35:05 -07:00
"sync"
"time"
2013-12-22 19:35:05 -07:00
2014-04-03 14:38:32 -07:00
"github.com/golang/groupcache/lru"
"github.com/juju/ratelimit"
"github.com/syncthing/syncthing/discover"
"github.com/syncthing/syncthing/protocol"
2014-09-08 02:48:26 -07:00
"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/opt"
2013-12-22 19:35:05 -07:00
)
2014-09-08 02:48:26 -07:00
const cacheLimitSeconds = 3600
2013-12-22 19:35:05 -07:00
var (
2014-06-27 13:39:03 -07:00
lock sync.Mutex
queries = 0
announces = 0
answered = 0
limited = 0
unknowns = 0
debug = false
lruSize = 1024
limitAvg = 1
limitBurst = 10
limiter *lru.Cache
2013-12-22 19:35:05 -07:00
)
func main() {
2014-02-20 09:40:15 -07:00
var listen string
var timestamp bool
2014-04-19 14:14:56 -07:00
var statsIntv int
var statsFile string
2014-09-08 02:48:26 -07:00
var dbDir string
2014-02-20 09:40:15 -07:00
flag.StringVar(&listen, "listen", ":22026", "Listen address")
2014-02-20 09:40:15 -07:00
flag.BoolVar(&debug, "debug", false, "Enable debug output")
flag.BoolVar(&timestamp, "timestamp", true, "Timestamp the log output")
2014-04-19 14:14:56 -07:00
flag.IntVar(&statsIntv, "stats-intv", 0, "Statistics output interval (s)")
2014-09-08 02:48:26 -07:00
flag.StringVar(&statsFile, "stats-file", "/var/discosrv/stats", "Statistics file name")
2014-06-27 13:39:03 -07:00
flag.IntVar(&lruSize, "limit-cache", lruSize, "Limiter cache entries")
flag.IntVar(&limitAvg, "limit-avg", limitAvg, "Allowed average package rate, per 10 s")
flag.IntVar(&limitBurst, "limit-burst", limitBurst, "Allowed burst size, packets")
2014-09-08 02:48:26 -07:00
flag.StringVar(&dbDir, "db-dir", "/var/discosrv/db", "Database directory")
2014-02-20 09:40:15 -07:00
flag.Parse()
2014-06-27 13:39:03 -07:00
limiter = lru.New(lruSize)
2014-02-20 09:40:15 -07:00
log.SetOutput(os.Stdout)
if !timestamp {
log.SetFlags(0)
}
addr, _ := net.ResolveUDPAddr("udp", listen)
2013-12-22 19:35:05 -07:00
conn, err := net.ListenUDP("udp", addr)
if err != nil {
2014-04-03 13:44:40 -07:00
log.Fatal(err)
2013-12-22 19:35:05 -07:00
}
2014-09-08 02:48:26 -07:00
parentDir := filepath.Dir(dbDir)
if _, err := os.Stat(parentDir); err != nil && os.IsNotExist(err) {
err = os.MkdirAll(parentDir, 0755)
if err != nil {
log.Fatal(err)
}
}
db, err := leveldb.OpenFile(dbDir, &opt.Options{CachedOpenFiles: 32})
if err != nil {
log.Fatal(err)
}
statsLog, err := os.OpenFile(statsFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
log.Fatal(err)
}
2014-04-19 14:14:56 -07:00
if statsIntv > 0 {
2014-09-08 02:48:26 -07:00
go logStats(statsLog, statsIntv)
2014-04-19 14:14:56 -07:00
}
2014-09-08 02:48:26 -07:00
go clean(statsLog, db)
2013-12-22 19:35:05 -07:00
var buf = make([]byte, 1024)
for {
2014-02-20 09:40:15 -07:00
buf = buf[:cap(buf)]
2013-12-22 19:35:05 -07:00
n, addr, err := conn.ReadFromUDP(buf)
2014-04-03 14:38:32 -07:00
if limit(addr) {
// Rate limit in effect for source
continue
}
2013-12-22 19:35:05 -07:00
if err != nil {
2014-04-03 13:44:40 -07:00
log.Fatal(err)
2013-12-22 19:35:05 -07:00
}
2014-04-03 14:38:32 -07:00
2014-02-20 09:40:15 -07:00
if n < 4 {
log.Printf("Received short packet (%d bytes)", n)
2013-12-22 19:35:05 -07:00
continue
}
2014-02-20 09:40:15 -07:00
buf = buf[:n]
magic := binary.BigEndian.Uint32(buf)
switch magic {
case discover.AnnouncementMagic:
2014-09-08 02:48:26 -07:00
handleAnnounceV2(db, addr, buf)
2014-02-20 09:40:15 -07:00
case discover.QueryMagic:
2014-09-08 02:48:26 -07:00
handleQueryV2(db, conn, addr, buf)
2014-04-19 14:14:56 -07:00
default:
lock.Lock()
unknowns++
lock.Unlock()
2014-04-03 13:44:40 -07:00
}
}
}
2014-02-20 09:40:15 -07:00
2014-04-03 14:38:32 -07:00
func limit(addr *net.UDPAddr) bool {
key := addr.IP.String()
lock.Lock()
defer lock.Unlock()
bkt, ok := limiter.Get(key)
if ok {
bkt := bkt.(*ratelimit.Bucket)
if bkt.TakeAvailable(1) != 1 {
// Rate limit exceeded; ignore packet
if debug {
2014-04-16 06:06:54 -07:00
log.Println("Rate limit exceeded for", key)
2014-04-03 14:38:32 -07:00
}
limited++
return true
}
} else {
if debug {
2014-04-16 06:06:54 -07:00
log.Println("New limiter for", key)
2014-04-03 14:38:32 -07:00
}
// One packet per ten seconds average rate, burst ten packets
2014-06-27 13:39:03 -07:00
limiter.Add(key, ratelimit.NewBucket(10*time.Second/time.Duration(limitAvg), int64(limitBurst)))
2014-04-03 14:38:32 -07:00
}
return false
}
2014-09-08 02:48:26 -07:00
func handleAnnounceV2(db *leveldb.DB, addr *net.UDPAddr, buf []byte) {
var pkt discover.Announce
2014-04-03 13:44:40 -07:00
err := pkt.UnmarshalXDR(buf)
if err != nil && err != io.EOF {
2014-04-03 13:44:40 -07:00
log.Println("AnnounceV2 Unmarshal:", err)
log.Println(hex.Dump(buf))
return
}
if debug {
log.Printf("<- %v %#v", addr, pkt)
}
2014-04-19 14:14:56 -07:00
lock.Lock()
announces++
lock.Unlock()
2014-04-03 13:44:40 -07:00
ip := addr.IP.To4()
if ip == nil {
ip = addr.IP.To16()
}
var addrs []address
2014-09-08 02:48:26 -07:00
now := time.Now().Unix()
for _, addr := range pkt.This.Addresses {
2014-04-03 13:44:40 -07:00
tip := addr.IP
if len(tip) == 0 {
tip = ip
}
addrs = append(addrs, address{
ip: tip,
port: addr.Port,
2014-09-08 02:48:26 -07:00
seen: now,
2014-04-03 13:44:40 -07:00
})
}
var id protocol.NodeID
if len(pkt.This.ID) == 32 {
// Raw node ID
copy(id[:], pkt.This.ID)
} else {
id.UnmarshalText(pkt.This.ID)
}
2014-09-08 02:48:26 -07:00
update(db, id, addrs)
2014-04-03 13:44:40 -07:00
}
2014-09-08 02:48:26 -07:00
func handleQueryV2(db *leveldb.DB, conn *net.UDPConn, addr *net.UDPAddr, buf []byte) {
var pkt discover.Query
2014-04-03 13:44:40 -07:00
err := pkt.UnmarshalXDR(buf)
if err != nil {
log.Println("QueryV2 Unmarshal:", err)
log.Println(hex.Dump(buf))
return
}
if debug {
log.Printf("<- %v %#v", addr, pkt)
}
var id protocol.NodeID
if len(pkt.NodeID) == 32 {
// Raw node ID
copy(id[:], pkt.NodeID)
} else {
id.UnmarshalText(pkt.NodeID)
}
2014-04-03 13:44:40 -07:00
lock.Lock()
queries++
lock.Unlock()
2014-09-08 02:48:26 -07:00
addrs := get(db, id)
now := time.Now().Unix()
if len(addrs) > 0 {
ann := discover.Announce{
Magic: discover.AnnouncementMagic,
This: discover.Node{
ID: pkt.NodeID,
},
2014-04-03 13:44:40 -07:00
}
2014-09-08 02:48:26 -07:00
for _, addr := range addrs {
if now-addr.seen > cacheLimitSeconds {
continue
}
ann.This.Addresses = append(ann.This.Addresses, discover.Address{IP: addr.ip, Port: addr.port})
2014-04-03 13:44:40 -07:00
}
if debug {
log.Printf("-> %v %#v", addr, pkt)
}
2014-09-08 02:48:26 -07:00
if len(ann.This.Addresses) == 0 {
return
}
tb := ann.MarshalXDR()
2014-04-03 13:44:40 -07:00
_, _, err = conn.WriteMsgUDP(tb, nil, addr)
if err != nil {
log.Println("QueryV2 response write:", err)
}
lock.Lock()
answered++
lock.Unlock()
}
}
2014-04-19 14:14:56 -07:00
// next blocks until the next instant that is a whole multiple of intv
// seconds (as computed by time.Time.Truncate) and returns that instant.
// For intv <= 0 it returns immediately with the current time shifted by
// intv seconds.
func next(intv int) time.Time {
	period := time.Duration(intv) * time.Second
	start := time.Now()
	boundary := start.Add(period).Truncate(period)
	time.Sleep(boundary.Sub(start))
	return boundary
}
2014-09-08 02:48:26 -07:00
func logStats(statsLog io.Writer, intv int) {
2014-04-03 13:44:40 -07:00
for {
2014-04-19 14:14:56 -07:00
t := next(intv)
2014-04-03 13:44:40 -07:00
lock.Lock()
2014-09-08 02:48:26 -07:00
fmt.Fprintf(statsLog, "%d Queries:%d Answered:%d Announces:%d Unknown:%d Limited:%d\n",
t.Unix(), queries, answered, announces, unknowns, limited)
2014-04-19 14:14:56 -07:00
2014-04-03 13:44:40 -07:00
queries = 0
2014-04-19 14:14:56 -07:00
announces = 0
2014-04-03 13:44:40 -07:00
answered = 0
2014-04-03 14:38:32 -07:00
limited = 0
2014-04-19 14:14:56 -07:00
unknowns = 0
2014-04-03 13:44:40 -07:00
lock.Unlock()
2013-12-22 19:35:05 -07:00
}
}
2014-09-08 02:48:26 -07:00
// get returns the addresses stored in db for node id, or nil if there is no
// record. Database errors other than "not found", and records that fail to
// decode, are logged and treated as "no addresses".
func get(db *leveldb.DB, id protocol.NodeID) []address {
	val, err := db.Get(id[:], nil)
	if err != nil {
		if err != leveldb.ErrNotFound {
			log.Println("get:", err)
		}
		return nil
	}

	var addrs addressList
	if err := addrs.UnmarshalXDR(val); err != nil {
		log.Println("get: unmarshal:", err)
		return nil
	}
	return addrs.addresses
}
// update merges addrs into the address list stored in db for node id and
// writes the result back. An incoming address replaces any entry with the
// same IP (port and seen time are overwritten); the rest are appended.
// If the existing record cannot be read because of a database error, the
// update is skipped so that record is not clobbered. A record that cannot
// be decoded is replaced by the incoming addresses. Errors are logged.
func update(db *leveldb.DB, id protocol.NodeID, addrs []address) {
	var merged addressList

	val, err := db.Get(id[:], nil)
	switch {
	case err == nil:
		if err := merged.UnmarshalXDR(val); err != nil {
			log.Println("update: unmarshal:", err)
			merged.addresses = nil
		}
	case err != leveldb.ErrNotFound:
		log.Println("update: get:", err)
		return
	}

nextAddr:
	for _, newAddr := range addrs {
		for i, exAddr := range merged.addresses {
			if bytes.Equal(newAddr.ip, exAddr.ip) {
				merged.addresses[i] = newAddr
				continue nextAddr
			}
		}
		merged.addresses = append(merged.addresses, newAddr)
	}

	if err := db.Put(id[:], merged.MarshalXDR(), nil); err != nil {
		log.Println("update: put:", err)
	}
}
// clean never returns. Every cacheLimitSeconds/2 seconds it scans every
// record in db and drops addresses not seen within the last
// cacheLimitSeconds. Records left empty are deleted and records that shrank
// are rewritten. A "Kept/Deleted/Took" summary line is then appended to
// statsLog. Database errors are logged and the scan continues.
func clean(statsLog io.Writer, db *leveldb.DB) {
	for {
		start := time.Now()
		nowSecs := start.Unix()
		var kept, deleted int64

		iter := db.NewIterator(nil, nil)
		for iter.Next() {
			var addrs addressList
			// Decode errors are ignored; whatever decoded is filtered below.
			addrs.UnmarshalXDR(iter.Value())

			// Collect unexpired addresses into a new slice. The previous
			// swap-with-last removal advanced the index past the element it
			// had just moved, so that element was never checked.
			var fresh []address
			for _, a := range addrs.addresses {
				if nowSecs-a.seen <= cacheLimitSeconds {
					fresh = append(fresh, a)
				}
			}

			// Delete empty records
			if len(fresh) == 0 {
				if err := db.Delete(iter.Key(), nil); err != nil {
					log.Println("clean: delete:", err)
				}
				deleted++
				continue
			}

			// Update changed records
			if len(fresh) != len(addrs.addresses) {
				addrs.addresses = fresh
				if err := db.Put(iter.Key(), addrs.MarshalXDR(), nil); err != nil {
					log.Println("clean: put:", err)
				}
			}
			kept++
		}
		// Read the iterator error before Release, which invalidates it.
		iterErr := iter.Error()
		iter.Release()
		if iterErr != nil {
			log.Println("clean: iterate:", iterErr)
		}

		fmt.Fprintf(statsLog, "%d Kept:%d Deleted:%d Took:%0.04fs\n", nowSecs, kept, deleted, time.Since(start).Seconds())
		time.Sleep(cacheLimitSeconds * time.Second / 2)
	}
}