--- /dev/null
+// Copyright 2016 The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "fmt"
+ "net/http"
+ "regexp"
+
+ "github.com/miekg/dns"
+ "github.com/prometheus/common/log"
+)
+
+// validRRs checks a slice of RRs received from the server against a DNSRRValidator.
+func validRRs(rrs *[]dns.RR, v *DNSRRValidator) bool {
+ // Fail the probe if there are no RRs of a given type, but a regexp match is required
+ // (i.e. FailIfNotMatchesRegexp is set).
+ if len(*rrs) == 0 && len(v.FailIfNotMatchesRegexp) > 0 {
+ return false
+ }
+ for _, rr := range *rrs {
+ log.Debugf("Validating RR: %q", rr)
+ for _, re := range v.FailIfMatchesRegexp {
+ match, err := regexp.MatchString(re, rr.String())
+ if err != nil {
+ log.Errorf("Error matching regexp %q: %s", re, err)
+ return false
+ }
+ if match {
+ return false
+ }
+ }
+ for _, re := range v.FailIfNotMatchesRegexp {
+ match, err := regexp.MatchString(re, rr.String())
+ if err != nil {
+ log.Errorf("Error matching regexp %q: %s", re, err)
+ return false
+ }
+ if !match {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+// validRcode checks rcode in the response against a list of valid rcodes.
+func validRcode(rcode int, valid []string) bool {
+ var validRcodes []int
+ // If no list of valid rcodes is specified, only NOERROR is considered valid.
+ if valid == nil {
+ validRcodes = append(validRcodes, dns.StringToRcode["NOERROR"])
+ } else {
+ for _, rcode := range valid {
+ rc, ok := dns.StringToRcode[rcode]
+ if !ok {
+ log.Errorf("Invalid rcode %v. Existing rcodes: %v", rcode, dns.RcodeToString)
+ return false
+ }
+ validRcodes = append(validRcodes, rc)
+ }
+ }
+ for _, rc := range validRcodes {
+ if rcode == rc {
+ return true
+ }
+ }
+ log.Debugf("%s (%d) is not one of the valid rcodes (%v)", dns.RcodeToString[rcode], rcode, validRcodes)
+ return false
+}
+
+func probeDNS(target string, w http.ResponseWriter, module Module) bool {
+ var numAnswer, numAuthority, numAdditional int
+ defer func() {
+ // These metrics can be used to build additional alerting based on the number of replies.
+ // They should be returned even in case of errors.
+ fmt.Fprintf(w, "probe_dns_answer_rrs %d\n", numAnswer)
+ fmt.Fprintf(w, "probe_dns_authority_rrs %d\n", numAuthority)
+ fmt.Fprintf(w, "probe_dns_additional_rrs %d\n", numAdditional)
+ }()
+
+ client := new(dns.Client)
+ client.Net = module.DNS.Protocol
+ client.Timeout = module.Timeout
+
+ qt := dns.TypeANY
+ if module.DNS.QueryType != "" {
+ var ok bool
+ qt, ok = dns.StringToType[module.DNS.QueryType]
+ if !ok {
+ log.Errorf("Invalid type %v. Existing types: %v", module.DNS.QueryType, dns.TypeToString)
+ return false
+ }
+ }
+
+ msg := new(dns.Msg)
+ msg.SetQuestion(dns.Fqdn(module.DNS.QueryName), qt)
+
+ response, _, err := client.Exchange(msg, target)
+ if err != nil {
+ log.Warnf("Error while sending a DNS query: %s", err)
+ return false
+ }
+ log.Debugf("Got response: %#v", response)
+
+ numAnswer, numAuthority, numAdditional = len(response.Answer), len(response.Ns), len(response.Extra)
+
+ if !validRcode(response.Rcode, module.DNS.ValidRcodes) {
+ return false
+ }
+ if !validRRs(&response.Answer, &module.DNS.ValidateAnswer) {
+ log.Debugf("Answer RRs validation failed")
+ return false
+ }
+ if !validRRs(&response.Ns, &module.DNS.ValidateAuthority) {
+ log.Debugf("Authority RRs validation failed")
+ return false
+ }
+ if !validRRs(&response.Extra, &module.DNS.ValidateAdditional) {
+ log.Debugf("Additional RRs validation failed")
+ return false
+ }
+ return true
+}
--- /dev/null
+// Copyright 2016 The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "net"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+var PROTOCOLS = [...]string{"udp", "tcp"}
+
+// startDNSServer starts a DNS server with a given handler function on a random port.
+// Returns the Server object itself as well as the net.Addr corresponding to the server port.
+func startDNSServer(protocol string, handler func(dns.ResponseWriter, *dns.Msg)) (*dns.Server, net.Addr) {
+ h := dns.NewServeMux()
+ h.HandleFunc(".", handler)
+ server := &dns.Server{Addr: ":0", Net: protocol, Handler: h}
+ go server.ListenAndServe()
+ // Wait until PacketConn becomes available, but give up after 1 second.
+ for i := 0; server.PacketConn == nil && i < 200; i++ {
+ if protocol == "tcp" && server.Listener != nil {
+ break
+ }
+ if protocol == "udp" && server.PacketConn != nil {
+ break
+ }
+ time.Sleep(5 * time.Millisecond)
+ }
+ if protocol == "tcp" {
+ return server, server.Listener.Addr()
+ }
+ return server, server.PacketConn.LocalAddr()
+}
+
+func recursiveDNSHandler(w dns.ResponseWriter, r *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(r)
+ answers := []string{
+ "example.com. 3600 IN A 127.0.0.1",
+ "example.com. 3600 IN A 127.0.0.2",
+ }
+ for _, rr := range answers {
+ a, err := dns.NewRR(rr)
+ if err != nil {
+ panic(err)
+ }
+ m.Answer = append(m.Answer, a)
+ }
+ if err := w.WriteMsg(m); err != nil {
+ panic(err)
+ }
+}
+
+func TestRecursiveDNSResponse(t *testing.T) {
+ tests := []struct {
+ Probe DNSProbe
+ ShouldSucceed bool
+ }{
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ }, true,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidRcodes: []string{"SERVFAIL", "NXDOMAIN"},
+ }, false,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidateAnswer: DNSRRValidator{
+ FailIfMatchesRegexp: []string{".*7200.*"},
+ FailIfNotMatchesRegexp: []string{".*3600.*"},
+ },
+ }, true,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidateAuthority: DNSRRValidator{
+ FailIfMatchesRegexp: []string{".*7200.*"},
+ },
+ }, true,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidateAdditional: DNSRRValidator{
+ FailIfNotMatchesRegexp: []string{".*3600.*"},
+ },
+ }, false,
+ },
+ }
+ expectedOutput := []string{
+ "probe_dns_answer_rrs 2\n",
+ "probe_dns_authority_rrs 0\n",
+ "probe_dns_additional_rrs 0\n",
+ }
+
+ for _, protocol := range PROTOCOLS {
+ server, addr := startDNSServer(protocol, recursiveDNSHandler)
+ defer server.Shutdown()
+
+ for i, test := range tests {
+ test.Probe.Protocol = protocol
+ recorder := httptest.NewRecorder()
+ result := probeDNS(addr.String(), recorder, Module{Timeout: time.Second, DNS: test.Probe})
+ if result != test.ShouldSucceed {
+ t.Fatalf("Test %d had unexpected result: %v", i, result)
+ }
+ body := recorder.Body.String()
+ for _, line := range expectedOutput {
+ if !strings.Contains(body, line) {
+ t.Fatalf("Did not find expected output in test %d: %q", i, line)
+ }
+ }
+ }
+ }
+}
+
+func authoritativeDNSHandler(w dns.ResponseWriter, r *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(r)
+
+ a, err := dns.NewRR("example.com. 3600 IN A 127.0.0.1")
+ if err != nil {
+ panic(err)
+ }
+ m.Answer = append(m.Answer, a)
+
+ authority := []string{
+ "example.com. 7200 IN NS ns1.isp.net.",
+ "example.com. 7200 IN NS ns2.isp.net.",
+ }
+ for _, rr := range authority {
+ a, err := dns.NewRR(rr)
+ if err != nil {
+ panic(err)
+ }
+ m.Ns = append(m.Ns, a)
+ }
+
+ additional := []string{
+ "ns1.isp.net. 7200 IN A 127.0.0.1",
+ "ns1.isp.net. 7200 IN AAAA ::1",
+ "ns2.isp.net. 7200 IN A 127.0.0.2",
+ }
+ for _, rr := range additional {
+ a, err := dns.NewRR(rr)
+ if err != nil {
+ panic(err)
+ }
+ m.Extra = append(m.Extra, a)
+ }
+
+ if err := w.WriteMsg(m); err != nil {
+ panic(err)
+ }
+}
+
+func TestAuthoritativeDNSResponse(t *testing.T) {
+ tests := []struct {
+ Probe DNSProbe
+ ShouldSucceed bool
+ }{
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ }, true,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidRcodes: []string{"SERVFAIL", "NXDOMAIN"},
+ }, false,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidateAnswer: DNSRRValidator{
+ FailIfMatchesRegexp: []string{".*3600.*"},
+ FailIfNotMatchesRegexp: []string{".*3600.*"},
+ },
+ }, false,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidateAnswer: DNSRRValidator{
+ FailIfMatchesRegexp: []string{".*7200.*"},
+ FailIfNotMatchesRegexp: []string{".*7200.*"},
+ },
+ }, false,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidateAuthority: DNSRRValidator{
+ FailIfNotMatchesRegexp: []string{"ns.*.isp.net"},
+ },
+ }, true,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidateAdditional: DNSRRValidator{
+ FailIfNotMatchesRegexp: []string{"^ns.*.isp"},
+ },
+ }, true,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidateAdditional: DNSRRValidator{
+ FailIfMatchesRegexp: []string{"^ns.*.isp"},
+ },
+ }, false,
+ },
+ }
+ expectedOutput := []string{
+ "probe_dns_answer_rrs 1\n",
+ "probe_dns_authority_rrs 2\n",
+ "probe_dns_additional_rrs 3\n",
+ }
+
+ for _, protocol := range PROTOCOLS {
+ server, addr := startDNSServer(protocol, authoritativeDNSHandler)
+ defer server.Shutdown()
+
+ for i, test := range tests {
+ test.Probe.Protocol = protocol
+ recorder := httptest.NewRecorder()
+ result := probeDNS(addr.String(), recorder, Module{Timeout: time.Second, DNS: test.Probe})
+ if result != test.ShouldSucceed {
+ t.Fatalf("Test %d had unexpected result: %v", i, result)
+ }
+ body := recorder.Body.String()
+ for _, line := range expectedOutput {
+ if !strings.Contains(body, line) {
+ t.Fatalf("Did not find expected output in test %d: %q", i, line)
+ }
+ }
+ }
+ }
+}
+
+func TestServfailDNSResponse(t *testing.T) {
+ tests := []struct {
+ Probe DNSProbe
+ ShouldSucceed bool
+ }{
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ }, false,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidRcodes: []string{"SERVFAIL", "NXDOMAIN"},
+ }, true,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ QueryType: "NOT_A_VALID_QUERY_TYPE",
+ }, false,
+ },
+ {
+ DNSProbe{
+ QueryName: "example.com",
+ ValidRcodes: []string{"NOT_A_VALID_RCODE"},
+ }, false,
+ },
+ }
+ expectedOutput := []string{
+ "probe_dns_answer_rrs 0\n",
+ "probe_dns_authority_rrs 0\n",
+ "probe_dns_additional_rrs 0\n",
+ }
+
+ for _, protocol := range PROTOCOLS {
+ // dns.HandleFailed returns SERVFAIL on everything
+ server, addr := startDNSServer(protocol, dns.HandleFailed)
+ defer server.Shutdown()
+
+ for i, test := range tests {
+ test.Probe.Protocol = protocol
+ recorder := httptest.NewRecorder()
+ result := probeDNS(addr.String(), recorder, Module{Timeout: time.Second, DNS: test.Probe})
+ if result != test.ShouldSucceed {
+ t.Fatalf("Test %d had unexpected result: %v", i, result)
+ }
+ body := recorder.Body.String()
+ for _, line := range expectedOutput {
+ if !strings.Contains(body, line) {
+ t.Fatalf("Did not find expected output in test %d: %q", i, line)
+ }
+ }
+ }
+ }
+}