Update config, signal handling and refactoring

This commit is contained in:
George 2020-03-08 00:59:44 -05:00
parent 041fdff2a8
commit a2be79ddcb
3 changed files with 63 additions and 32 deletions

View File

@ -3,8 +3,8 @@ package main
import ( import (
"fmt" "fmt"
"net" "net"
"strings"
"os" "os"
"strings"
"github.com/zhoreeq/meshname/src/meshname" "github.com/zhoreeq/meshname/src/meshname"
) )
@ -28,7 +28,7 @@ func main() {
fmt.Println("Invalid domain") fmt.Println("Invalid domain")
return return
} }
subDomain := labels[len(labels) - 2] subDomain := labels[len(labels)-2]
if len(subDomain) != 26 { if len(subDomain) != 26 {
fmt.Println("Invalid subdomain length") fmt.Println("Invalid subdomain length")
return return

View File

@ -3,8 +3,10 @@ package main
import ( import (
"net" "net"
"os" "os"
"os/signal"
"fmt" "fmt"
"flag" "flag"
"syscall"
"github.com/gologme/log" "github.com/gologme/log"
@ -47,9 +49,24 @@ func main() {
os.Exit(1) os.Exit(1)
} }
s.Init(logger, meshname.MeshnameOptions{ListenAddr: *listenAddr, ConfigPath: *useconffile, ValidSubnet: validSubnet}) s.Init(logger, *listenAddr, *useconffile, validSubnet)
s.Start() s.Start()
c := make(chan os.Signal, 1)
r := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
signal.Notify(r, os.Interrupt, syscall.SIGHUP)
defer s.Stop()
for {
select {
case _ = <-c:
goto exit
case _ = <-r:
s.UpdateConfig()
}
}
default: default:
flag.PrintDefaults() flag.PrintDefaults()
} }
exit:
} }

View File

@ -49,24 +49,19 @@ func GenConf(target string) (string, error) {
} }
type MeshnameServer struct { type MeshnameServer struct {
validSubnet *net.IPNet validSubnet *net.IPNet
log *log.Logger log *log.Logger
listenAddr, zoneConfigPath string listenAddr, zoneConfigPath string
zoneConfig map[string][]dns.RR zoneConfig map[string][]dns.RR
dnsClient *dns.Client dnsClient *dns.Client
dnsServer *dns.Server
} }
type MeshnameOptions struct { func (s *MeshnameServer) Init(log *log.Logger, listenAddr string, zoneConfigPath string, validSubnet *net.IPNet) {
ListenAddr, ConfigPath string
ValidSubnet *net.IPNet
}
func (s *MeshnameServer) Init(log *log.Logger, options interface{}) {
mnoptions := options.(MeshnameOptions)
s.log = log s.log = log
s.listenAddr = mnoptions.ListenAddr s.listenAddr = listenAddr
s.validSubnet = mnoptions.ValidSubnet s.validSubnet = validSubnet
s.zoneConfigPath = mnoptions.ConfigPath s.zoneConfigPath = zoneConfigPath
s.zoneConfig = make(map[string][]dns.RR) s.zoneConfig = make(map[string][]dns.RR)
if s.dnsClient == nil { if s.dnsClient == nil {
s.dnsClient = new(dns.Client) s.dnsClient = new(dns.Client)
@ -115,11 +110,19 @@ func (s *MeshnameServer) LoadConfig() {
s.log.Infoln("Meshname config loaded:", s.zoneConfigPath) s.log.Infoln("Meshname config loaded:", s.zoneConfigPath)
} }
func (s *MeshnameServer) Start() { func (s *MeshnameServer) Stop() error {
dnsServer := &dns.Server{Addr: s.listenAddr, Net: "udp"} if s.dnsServer != nil {
s.log.Infoln("Started meshnamed on:", s.listenAddr) s.dnsServer.Shutdown()
}
return nil
}
func (s *MeshnameServer) Start() error {
s.dnsServer = &dns.Server{Addr: s.listenAddr, Net: "udp"}
dns.HandleFunc(DomainZone, s.handleRequest) dns.HandleFunc(DomainZone, s.handleRequest)
dnsServer.ListenAndServe() go s.dnsServer.ListenAndServe()
s.log.Infoln("Started meshnamed on:", s.listenAddr)
return nil
} }
func (s *MeshnameServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) { func (s *MeshnameServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
@ -135,24 +138,23 @@ func (s *MeshnameServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
} }
subDomain := labels[len(labels)-2] subDomain := labels[len(labels)-2]
resolvedAddr, err := IPFromDomain(subDomain)
if err != nil {
s.log.Debugln(err)
continue
}
if !s.validSubnet.Contains(resolvedAddr) {
s.log.Debugln("Error: subnet doesn't match")
continue
}
if records, ok := s.zoneConfig[subDomain]; ok { if records, ok := s.zoneConfig[subDomain]; ok {
for _, rec := range records { for _, rec := range records {
if h := rec.Header(); h.Name == q.Name && h.Rrtype == q.Qtype && h.Class == q.Qclass { if h := rec.Header(); h.Name == q.Name && h.Rrtype == q.Qtype && h.Class == q.Qclass {
m.Answer = append(m.Answer, rec) m.Answer = append(m.Answer, rec)
} }
} }
} else if ra := w.RemoteAddr().String(); strings.HasPrefix(ra, "[::1]:") || strings.HasPrefix(ra, "127.0.0.1:") { } else if s.isRemoteLookupAllowed(w.RemoteAddr()) {
// TODO prefix whitelists ?
// do remote lookups only for local clients // do remote lookups only for local clients
resolvedAddr, err := IPFromDomain(subDomain)
if err != nil {
s.log.Debugln(err)
continue
}
if !s.validSubnet.Contains(resolvedAddr) {
s.log.Debugln("Error: subnet doesn't match")
continue
}
remoteLookups[resolvedAddr.String()] = append(remoteLookups[resolvedAddr.String()], q) remoteLookups[resolvedAddr.String()] = append(remoteLookups[resolvedAddr.String()], q)
} }
} }
@ -170,3 +172,15 @@ func (s *MeshnameServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
} }
func (s *MeshnameServer) isRemoteLookupAllowed(addr net.Addr) bool {
// TODO prefix whitelists ?
ra := addr.String()
return strings.HasPrefix(ra, "[::1]:") || strings.HasPrefix(ra, "127.0.0.1:")
}
func (s *MeshnameServer) UpdateConfig() error {
s.Stop()
s.LoadConfig()
s.Start()
return nil
}