diff --git a/cmd/meshnamed/main.go b/cmd/meshnamed/main.go index 590dacd..9ee4a57 100644 --- a/cmd/meshnamed/main.go +++ b/cmd/meshnamed/main.go @@ -71,14 +71,12 @@ func main() { return } - s := meshname.New(logger, listenAddr) - - if networks, err := parseNetworks(networksconf); err == nil { - s.ConfigureNetworks(networks) - } else { + networks, err := parseNetworks(networksconf) + if err != nil { logger.Fatalln(err) } + s := meshname.New(logger, listenAddr, networks) if useconffile != "" { if err := loadConfig(s, useconffile); err != nil { logger.Fatalln(err) diff --git a/pkg/meshname/server.go b/pkg/meshname/server.go index 82c8ea1..669c91c 100644 --- a/pkg/meshname/server.go +++ b/pkg/meshname/server.go @@ -25,7 +25,7 @@ type MeshnameServer struct { } // New is a constructor for MeshnameServer -func New(log *log.Logger, listenAddr string) *MeshnameServer { +func New(log *log.Logger, listenAddr string, networks map[string]*net.IPNet) *MeshnameServer { dnsClient := new(dns.Client) dnsClient.Timeout = 5000000000 // increased 5 seconds timeout @@ -33,7 +33,7 @@ func New(log *log.Logger, listenAddr string) *MeshnameServer { log: log, listenAddr: listenAddr, dnsRecords: make(map[string][]dns.RR), - networks: make(map[string]*net.IPNet), + networks: networks, dnsClient: dnsClient, } } @@ -55,12 +55,19 @@ func (s *MeshnameServer) Start() error { defer s.startedLock.Unlock() if !s.started { - s.dnsServer = &dns.Server{Addr: s.listenAddr, Net: "udp"} + waitStarted := make(chan struct{}) + s.dnsServer = &dns.Server{ + Addr: s.listenAddr, + Net: "udp", + NotifyStartedFunc: func(){ close(waitStarted) }, + } for tld, subnet := range s.networks { dns.HandleFunc(tld, s.handleRequest) s.log.Debugln("Handling:", tld, subnet) } go s.dnsServer.ListenAndServe() + <-waitStarted + s.log.Debugln("MeshnameServer started") s.started = true return nil @@ -75,10 +82,6 @@ func (s *MeshnameServer) ConfigureDNSRecords(dnsRecords map[string][]dns.RR) { s.dnsRecordsLock.Unlock() } -func (s *MeshnameServer) ConfigureNetworks(networks map[string]*net.IPNet) { - s.networks = networks -} - func (s *MeshnameServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) { var remoteLookups = make(map[string][]dns.Question) m := new(dns.Msg) diff --git a/pkg/meshname/server_test.go b/pkg/meshname/server_test.go index 7179f8d..7bc5cc4 100644 --- a/pkg/meshname/server_test.go +++ b/pkg/meshname/server_test.go @@ -15,10 +15,10 @@ import ( func TestServerLocalDomain(t *testing.T) { bindAddr := "[::1]:54545" log := log.New(os.Stdout, "", log.Flags()) - - ts := meshname.New(log, bindAddr) yggIPNet := &net.IPNet{IP: net.ParseIP("200::"), Mask: net.CIDRMask(7, 128)} - ts.ConfigureNetworks(map[string]*net.IPNet{"ygg": yggIPNet, "meshname": yggIPNet}) + networks := map[string]*net.IPNet{"meshname": yggIPNet} + + ts := meshname.New(log, bindAddr, networks) exampleConfig := make(map[string][]string) exampleConfig["aiarnf2wpqjxkp6rhivuxbondy"] = append(exampleConfig["aiarnf2wpqjxkp6rhivuxbondy"],