Factor out common resolution logic
authorConor Broderick <conor.broderick@robustperception.io>
Wed, 14 Jun 2017 15:48:27 +0000 (16:48 +0100)
committerBrian Brazil <brian.brazil@robustperception.io>
Wed, 14 Jun 2017 15:48:27 +0000 (16:48 +0100)
Have common metrics and configuration for protocol selection.

This changes the config file format.

README.md
dns.go
dns_test.go
http.go
http_test.go
icmp.go
main.go
tcp.go
tcp_test.go
utils.go [new file with mode: 0644]

index 5b23dad71909254838d72c49f7f202c4d87916a4..7460c293038a7d22a14f4cab4d0c38178a0fccab 100644 (file)
--- a/README.md
+++ b/README.md
@@ -45,8 +45,7 @@ modules:
         - "Download the latest version here"
       tls_config:
         insecure_skip_verify: false
-      protocol: "tcp" # accepts "tcp/tcp4/tcp6", defaults to "tcp"
-      preferred_ip_protocol: "ip4" # used for "tcp", defaults to "ip6"
+      preferred_ip_protocol: "ip4" # defaults to "ip6"
   http_post_2xx:
     prober: http
     timeout: 5s
@@ -55,11 +54,9 @@ modules:
       headers:
         Content-Type: application/json
       body: '{}'
-  tcp_connect_v4_example:
+  tcp_connect_example:
     prober: tcp
     timeout: 5s
-    tcp:
-      protocol: "tcp4"
   irc_banner_example:
     prober: tcp
     timeout: 5s
@@ -74,7 +71,6 @@ modules:
     prober: icmp
     timeout: 5s
     icmp:
-      protocol: "icmp"
       preferred_ip_protocol: "ip4"
   dns_udp_example:
     prober: dns
@@ -98,8 +94,8 @@ modules:
   dns_tcp_example:
     prober: dns
     dns:
-      protocol: "tcp" # accepts "tcp/tcp4/tcp6/udp/udp4/udp6", defaults to "udp"
-      preferred_ip_protocol: "ip4" # used for "udp/tcp", defaults to "ip6"
+      protocol: "tcp" # defaults to "udp"
+      preferred_ip_protocol: "ip4" #  defaults to "ip6"
       query_name: "www.prometheus.io"
 ```
 
diff --git a/dns.go b/dns.go
index 6c8f56f30813f474de764e1a1434865df50f99cc..9d9de2451d176732c6e215e4fadc244f5f09b720 100644 (file)
--- a/dns.go
+++ b/dns.go
@@ -15,7 +15,6 @@ package main
 
 import (
        "net"
-       "net/http"
        "regexp"
 
        "github.com/miekg/dns"
@@ -81,13 +80,9 @@ func validRcode(rcode int, valid []string) bool {
        return false
 }
 
-func probeDNS(target string, w http.ResponseWriter, module Module, registry *prometheus.Registry) bool {
+func probeDNS(target string, module Module, registry *prometheus.Registry) bool {
        var numAnswer, numAuthority, numAdditional int
-       var dialProtocol, fallbackProtocol string
-       probeIPProtocolGauge := prometheus.NewGauge(prometheus.GaugeOpts{
-               Name: "probe_ip_protocol",
-               Help: "Specifies whether probe ip protocl is IP4 or IP6",
-       })
+       var dialProtocol string
        probeDNSAnswerRRSGauge := prometheus.NewGauge(prometheus.GaugeOpts{
                Name: "probe_dns_answer_rrs",
                Help: "Returns number of entries in the answer resource record list",
@@ -100,7 +95,6 @@ func probeDNS(target string, w http.ResponseWriter, module Module, registry *pro
                Name: "probe_dns_additional_rrs",
                Help: "Returns number of entries in the additional resource record list",
        })
-       registry.MustRegister(probeIPProtocolGauge)
        registry.MustRegister(probeDNSAnswerRRSGauge)
        registry.MustRegister(probeDNSAuthorityRRSGauge)
        registry.MustRegister(probeDNSAdditionalRRSGauge)
@@ -113,41 +107,29 @@ func probeDNS(target string, w http.ResponseWriter, module Module, registry *pro
                probeDNSAdditionalRRSGauge.Set(float64(numAdditional))
        }()
 
-       if module.DNS.Protocol == "" {
-               module.DNS.Protocol = "udp"
-       }
+       var ip *net.IPAddr
+       var err error
 
-       if (module.DNS.Protocol == "tcp" || module.DNS.Protocol == "udp") && module.DNS.PreferredIPProtocol == "" {
-               module.DNS.PreferredIPProtocol = "ip6"
-       }
-       if module.DNS.PreferredIPProtocol == "ip6" {
-               fallbackProtocol = "ip4"
-       } else {
-               fallbackProtocol = "ip6"
+       if module.DNS.TransportProtocol == "" {
+               module.DNS.TransportProtocol = "udp"
        }
 
-       dialProtocol = module.DNS.Protocol
-       if module.DNS.Protocol == "udp" || module.DNS.Protocol == "tcp" {
-               targetAddress, _, _ := net.SplitHostPort(target)
-               ip, err := net.ResolveIPAddr(module.DNS.PreferredIPProtocol, targetAddress)
+       if module.DNS.TransportProtocol == "udp" || module.DNS.TransportProtocol == "tcp" {
+               targetAddr, _, _ := net.SplitHostPort(target)
+               ip, err = chooseProtocol(module.DNS.PreferredIPProtocol, targetAddr, registry)
                if err != nil {
-                       ip, err = net.ResolveIPAddr(fallbackProtocol, targetAddress)
-                       if err != nil {
-                               return false
-                       }
-               }
-
-               if ip.IP.To4() == nil {
-                       dialProtocol = module.DNS.Protocol + "6"
-               } else {
-                       dialProtocol = module.DNS.Protocol + "4"
+                       log.Error(err)
+                       return false
                }
+       } else {
+               log.Errorf("Configuration error: Expected transport protocol udp or tcp, got %s", module.DNS.TransportProtocol)
+               return false
        }
 
-       if dialProtocol[len(dialProtocol)-1] == '6' {
-               probeIPProtocolGauge.Set(6)
+       if ip.IP.To4() == nil {
+               dialProtocol = module.DNS.TransportProtocol + "6"
        } else {
-               probeIPProtocolGauge.Set(4)
+               dialProtocol = module.DNS.TransportProtocol + "4"
        }
 
        client := new(dns.Client)
@@ -166,7 +148,6 @@ func probeDNS(target string, w http.ResponseWriter, module Module, registry *pro
 
        msg := new(dns.Msg)
        msg.SetQuestion(dns.Fqdn(module.DNS.QueryName), qt)
-
        response, _, err := client.Exchange(msg, target)
        if err != nil {
                log.Warnf("Error while sending a DNS query: %s", err)
index aaafa5e773ef858b7766cc401524dcacc7f746d4..ff36a4f367f89cf35555aaa07f5124a31fc1a4c5 100644 (file)
@@ -15,7 +15,6 @@ package main
 
 import (
        "net"
-       "net/http/httptest"
        "runtime"
        "testing"
        "time"
@@ -116,11 +115,10 @@ func TestRecursiveDNSResponse(t *testing.T) {
                defer server.Shutdown()
 
                for i, test := range tests {
-                       test.Probe.Protocol = protocol
-                       recorder := httptest.NewRecorder()
+                       test.Probe.TransportProtocol = protocol
                        registry := prometheus.NewPedanticRegistry()
                        registry.Gather()
-                       result := probeDNS(addr.String(), recorder, Module{Timeout: time.Second, DNS: test.Probe}, registry)
+                       result := probeDNS(addr.String(), Module{Timeout: time.Second, DNS: test.Probe}, registry)
                        if result != test.ShouldSucceed {
                                t.Fatalf("Test %d had unexpected result: %v", i, result)
                        }
@@ -243,10 +241,9 @@ func TestAuthoritativeDNSResponse(t *testing.T) {
                defer server.Shutdown()
 
                for i, test := range tests {
-                       test.Probe.Protocol = protocol
-                       recorder := httptest.NewRecorder()
+                       test.Probe.TransportProtocol = protocol
                        registry := prometheus.NewRegistry()
-                       result := probeDNS(addr.String(), recorder, Module{Timeout: time.Second, DNS: test.Probe}, registry)
+                       result := probeDNS(addr.String(), Module{Timeout: time.Second, DNS: test.Probe}, registry)
                        if result != test.ShouldSucceed {
                                t.Fatalf("Test %d had unexpected result: %v", i, result)
                        }
@@ -300,10 +297,9 @@ func TestServfailDNSResponse(t *testing.T) {
                defer server.Shutdown()
 
                for i, test := range tests {
-                       test.Probe.Protocol = protocol
-                       recorder := httptest.NewRecorder()
+                       test.Probe.TransportProtocol = protocol
                        registry := prometheus.NewRegistry()
-                       result := probeDNS(addr.String(), recorder, Module{Timeout: time.Second, DNS: test.Probe}, registry)
+                       result := probeDNS(addr.String(), Module{Timeout: time.Second, DNS: test.Probe}, registry)
                        if result != test.ShouldSucceed {
                                t.Fatalf("Test %d had unexpected result: %v", i, result)
                        }
@@ -343,13 +339,13 @@ func TestDNSProtocol(t *testing.T) {
                module := Module{
                        Timeout: time.Second,
                        DNS: DNSProbe{
-                               QueryName: "example.com",
-                               Protocol:  protocol + "4",
+                               QueryName:           "example.com",
+                               TransportProtocol:   protocol,
+                               PreferredIPProtocol: "ip4",
                        },
                }
-               recorder := httptest.NewRecorder()
                registry := prometheus.NewRegistry()
-               result := probeDNS(net.JoinHostPort("localhost", port), recorder, module, registry)
+               result := probeDNS(net.JoinHostPort("localhost", port), module, registry)
                if !result {
                        t.Fatalf("DNS protocol: \"%v4\" connection test failed, expected success.", protocol)
                }
@@ -367,13 +363,13 @@ func TestDNSProtocol(t *testing.T) {
                module = Module{
                        Timeout: time.Second,
                        DNS: DNSProbe{
-                               QueryName: "example.com",
-                               Protocol:  protocol + "6",
+                               QueryName:           "example.com",
+                               TransportProtocol:   protocol,
+                               PreferredIPProtocol: "ip6",
                        },
                }
-               recorder = httptest.NewRecorder()
                registry = prometheus.NewRegistry()
-               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module, registry)
+               result = probeDNS(net.JoinHostPort("localhost", port), module, registry)
                if !result {
                        t.Fatalf("DNS protocol: \"%v6\" connection test failed, expected success.", protocol)
                }
@@ -391,13 +387,12 @@ func TestDNSProtocol(t *testing.T) {
                        Timeout: time.Second,
                        DNS: DNSProbe{
                                QueryName:           "example.com",
-                               Protocol:            protocol,
+                               TransportProtocol:   protocol,
                                PreferredIPProtocol: "ip6",
                        },
                }
-               recorder = httptest.NewRecorder()
                registry = prometheus.NewRegistry()
-               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module, registry)
+               result = probeDNS(net.JoinHostPort("localhost", port), module, registry)
                if !result {
                        t.Fatalf("DNS protocol: \"%v\", preferred \"ip6\" connection test failed, expected success.", protocol)
                }
@@ -415,13 +410,12 @@ func TestDNSProtocol(t *testing.T) {
                        Timeout: time.Second,
                        DNS: DNSProbe{
                                QueryName:           "example.com",
-                               Protocol:            protocol,
+                               TransportProtocol:   protocol,
                                PreferredIPProtocol: "ip4",
                        },
                }
-               recorder = httptest.NewRecorder()
                registry = prometheus.NewRegistry()
-               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module, registry)
+               result = probeDNS(net.JoinHostPort("localhost", port), module, registry)
                if !result {
                        t.Fatalf("DNS protocol: \"%v\", preferred \"ip4\" connection test failed, expected success.", protocol)
                }
@@ -439,13 +433,12 @@ func TestDNSProtocol(t *testing.T) {
                module = Module{
                        Timeout: time.Second,
                        DNS: DNSProbe{
-                               QueryName: "example.com",
-                               Protocol:  protocol,
+                               QueryName:         "example.com",
+                               TransportProtocol: protocol,
                        },
                }
-               recorder = httptest.NewRecorder()
                registry = prometheus.NewRegistry()
-               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module, registry)
+               result = probeDNS(net.JoinHostPort("localhost", port), module, registry)
                if !result {
                        t.Fatalf("DNS protocol: \"%v\" connection test failed, expected success.", protocol)
                }
@@ -466,9 +459,8 @@ func TestDNSProtocol(t *testing.T) {
                                QueryName: "example.com",
                        },
                }
-               recorder = httptest.NewRecorder()
                registry = prometheus.NewRegistry()
-               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module, registry)
+               result = probeDNS(net.JoinHostPort("localhost", port), module, registry)
                if protocol == "udp" {
                        if !result {
                                t.Fatalf("DNS test connection with protocol %s failed, expected success.", protocol)
diff --git a/http.go b/http.go
index 140139d9189d68ff516c55b0c4adfdac1952b830..d543d0f8e3af6c03d03a5a7dc71c1dfce95ed4a3 100644 (file)
--- a/http.go
+++ b/http.go
@@ -56,9 +56,9 @@ func matchRegularExpressions(reader io.Reader, config HTTPProbe) bool {
        return true
 }
 
-func probeHTTP(target string, w http.ResponseWriter, module Module, registry *prometheus.Registry) (success bool) {
+func probeHTTP(target string, module Module, registry *prometheus.Registry) (success bool) {
        var redirects int
-       var dialProtocol, fallbackProtocol string
+       var dialProtocol string
 
        var (
                contentLengthGauge = prometheus.NewGauge(prometheus.GaugeOpts{
@@ -81,11 +81,6 @@ func probeHTTP(target string, w http.ResponseWriter, module Module, registry *pr
                        Help: "Response HTTP status code",
                })
 
-               probeIPProtocolGauge = prometheus.NewGauge(prometheus.GaugeOpts{
-                       Name: "probe_ip_protocol",
-                       Help: "Specifies whether probe ip protocl is IP4 or IP6",
-               })
-
                probeSSLEarliestCertExpiryGauge = prometheus.NewGauge(prometheus.GaugeOpts{
                        Name: "probe_ssl_earliest_cert_expiry",
                        Help: "Returns earliest SSL cert expiry in unixtime",
@@ -96,56 +91,32 @@ func probeHTTP(target string, w http.ResponseWriter, module Module, registry *pr
        registry.MustRegister(redirectsGauge)
        registry.MustRegister(isSSLGauge)
        registry.MustRegister(statusCodeGauge)
-       registry.MustRegister(probeIPProtocolGauge)
 
        config := module.HTTP
 
-       if module.HTTP.Protocol == "" {
-               module.HTTP.Protocol = "tcp"
-       }
-
-       if module.HTTP.Protocol == "tcp" && module.HTTP.PreferredIPProtocol == "" {
-               module.HTTP.PreferredIPProtocol = "ip6"
-       }
-       if module.HTTP.PreferredIPProtocol == "ip6" {
-               fallbackProtocol = "ip4"
-       } else {
-               fallbackProtocol = "ip6"
-       }
        if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") {
                target = "http://" + target
        }
 
-       dialProtocol = module.HTTP.Protocol
-       if module.HTTP.Protocol == "tcp" {
-               targetURL, err := url.Parse(target)
-               if err != nil {
-                       return false
-               }
-               targetHost, _, err := net.SplitHostPort(targetURL.Host)
-               // If split fails, assuming it's a hostname without port part
-               if err != nil {
-                       targetHost = targetURL.Host
-               }
-               ip, err := net.ResolveIPAddr(module.HTTP.PreferredIPProtocol, targetHost)
-               if err != nil {
-                       ip, err = net.ResolveIPAddr(fallbackProtocol, targetHost)
-                       if err != nil {
-                               return false
-                       }
-               }
+       targetURL, err := url.Parse(target)
+       if err != nil {
+               return false
+       }
+       targetHost, targetPort, err := net.SplitHostPort(targetURL.Host)
+       // If split fails, assuming it's a hostname without port part.
+       if err != nil {
+               targetHost = targetURL.Host
+       }
 
-               if ip.IP.To4() == nil {
-                       dialProtocol = "tcp6"
-               } else {
-                       dialProtocol = "tcp4"
-               }
+       ip, err := chooseProtocol(module.HTTP.PreferredIPProtocol, targetHost, registry)
+       if err != nil {
+               return false
        }
 
-       if dialProtocol == "tcp6" {
-               probeIPProtocolGauge.Set(6)
+       if ip.IP.To4() == nil {
+               dialProtocol = "tcp6"
        } else {
-               probeIPProtocolGauge.Set(4)
+               dialProtocol = "tcp4"
        }
 
        client := &http.Client{
@@ -180,6 +151,13 @@ func probeHTTP(target string, w http.ResponseWriter, module Module, registry *pr
        }
 
        request, err := http.NewRequest(config.Method, target, nil)
+       request.Host = targetURL.Host
+       if targetPort == "" {
+               targetURL.Host = ip.String()
+       } else {
+               targetURL.Host = net.JoinHostPort(ip.String(), targetPort)
+       }
+
        if err != nil {
                log.Errorf("Error creating request for target %s: %s", target, err)
                return
@@ -197,7 +175,6 @@ func probeHTTP(target string, w http.ResponseWriter, module Module, registry *pr
        if config.Body != "" {
                request.Body = ioutil.NopCloser(strings.NewReader(config.Body))
        }
-
        resp, err := client.Do(request)
 
        // Err won't be nil if redirects were turned off. See https://github.com/golang/go/issues/3795
@@ -239,6 +216,5 @@ func probeHTTP(target string, w http.ResponseWriter, module Module, registry *pr
        statusCodeGauge.Set(float64(resp.StatusCode))
        contentLengthGauge.Set(float64(resp.ContentLength))
        redirectsGauge.Set(float64(redirects))
-
        return
 }
index c08f7e46220c14f016d27a5e2bad27ffab2c583a..bb8063a7f0ee8411826379ab2865761d08b4360b 100644 (file)
@@ -49,7 +49,7 @@ func TestHTTPStatusCodes(t *testing.T) {
                defer ts.Close()
                registry := prometheus.NewRegistry()
                recorder := httptest.NewRecorder()
-               result := probeHTTP(ts.URL, recorder,
+               result := probeHTTP(ts.URL,
                        Module{Timeout: time.Second, HTTP: HTTPProbe{ValidStatusCodes: test.ValidStatusCodes}}, registry)
                body := recorder.Body.String()
                if result != test.ShouldSucceed {
@@ -69,7 +69,7 @@ func TestRedirectFollowed(t *testing.T) {
        // Follow redirect, should succeed with 200.
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder, Module{Timeout: time.Second, HTTP: HTTPProbe{}}, registry)
+       result := probeHTTP(ts.URL, Module{Timeout: time.Second, HTTP: HTTPProbe{}}, registry)
        body := recorder.Body.String()
        if !result {
                t.Fatalf("Redirect test failed unexpectedly, got %s", body)
@@ -94,7 +94,7 @@ func TestRedirectNotFollowed(t *testing.T) {
        // Follow redirect, should succeed with 200.
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder,
+       result := probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{NoFollowRedirects: true, ValidStatusCodes: []int{302}}}, registry)
        body := recorder.Body.String()
        if !result {
@@ -113,7 +113,7 @@ func TestPost(t *testing.T) {
 
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder,
+       result := probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{Method: "POST"}}, registry)
        body := recorder.Body.String()
        if !result {
@@ -128,7 +128,7 @@ func TestFailIfNotSSL(t *testing.T) {
 
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder,
+       result := probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfNotSSL: true}}, registry)
        body := recorder.Body.String()
        if result {
@@ -152,7 +152,7 @@ func TestFailIfMatchesRegexp(t *testing.T) {
 
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder,
+       result := probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfMatchesRegexp: []string{"could not connect to database"}}}, registry)
        body := recorder.Body.String()
        if result {
@@ -166,7 +166,7 @@ func TestFailIfMatchesRegexp(t *testing.T) {
 
        recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeHTTP(ts.URL, recorder,
+       result = probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfMatchesRegexp: []string{"could not connect to database"}}}, registry)
        body = recorder.Body.String()
        if !result {
@@ -182,7 +182,7 @@ func TestFailIfMatchesRegexp(t *testing.T) {
 
        recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeHTTP(ts.URL, recorder,
+       result = probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfMatchesRegexp: []string{"could not connect to database", "internal error"}}}, registry)
        body = recorder.Body.String()
        if result {
@@ -196,7 +196,7 @@ func TestFailIfMatchesRegexp(t *testing.T) {
 
        recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeHTTP(ts.URL, recorder,
+       result = probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfMatchesRegexp: []string{"could not connect to database", "internal error"}}}, registry)
        body = recorder.Body.String()
        if !result {
@@ -212,7 +212,7 @@ func TestFailIfNotMatchesRegexp(t *testing.T) {
 
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder,
+       result := probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfNotMatchesRegexp: []string{"Download the latest version here"}}}, registry)
        body := recorder.Body.String()
        if result {
@@ -226,7 +226,7 @@ func TestFailIfNotMatchesRegexp(t *testing.T) {
 
        recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeHTTP(ts.URL, recorder,
+       result = probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfNotMatchesRegexp: []string{"Download the latest version here"}}}, registry)
        body = recorder.Body.String()
        if !result {
@@ -242,7 +242,7 @@ func TestFailIfNotMatchesRegexp(t *testing.T) {
 
        recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeHTTP(ts.URL, recorder,
+       result = probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfNotMatchesRegexp: []string{"Download the latest version here", "Copyright 2015"}}}, registry)
        body = recorder.Body.String()
        if result {
@@ -256,7 +256,7 @@ func TestFailIfNotMatchesRegexp(t *testing.T) {
 
        recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeHTTP(ts.URL, recorder,
+       result = probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{FailIfNotMatchesRegexp: []string{"Download the latest version here", "Copyright 2015"}}}, registry)
        body = recorder.Body.String()
        if !result {
@@ -285,9 +285,8 @@ func TestHTTPHeaders(t *testing.T) {
                w.WriteHeader(http.StatusOK)
        }))
        defer ts.Close()
-       recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder, Module{Timeout: time.Second, HTTP: HTTPProbe{
+       result := probeHTTP(ts.URL, Module{Timeout: time.Second, HTTP: HTTPProbe{
                Headers: headers,
        }}, registry)
        if !result {
@@ -302,7 +301,7 @@ func TestFailIfSelfSignedCA(t *testing.T) {
 
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder,
+       result := probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{
                        TLSConfig: config.TLSConfig{InsecureSkipVerify: false},
                }}, registry)
@@ -327,7 +326,7 @@ func TestSucceedIfSelfSignedCA(t *testing.T) {
 
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder,
+       result := probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{
                        TLSConfig: config.TLSConfig{InsecureSkipVerify: true},
                }}, registry)
@@ -352,7 +351,7 @@ func TestTLSConfigIsIgnoredForPlainHTTP(t *testing.T) {
 
        recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeHTTP(ts.URL, recorder,
+       result := probeHTTP(ts.URL,
                Module{Timeout: time.Second, HTTP: HTTPProbe{
                        TLSConfig: config.TLSConfig{InsecureSkipVerify: false},
                }}, registry)
diff --git a/icmp.go b/icmp.go
index b8f501d724a1d61beee2510ba502215f022cf823..6a61070a28b60bda33687c196cf8aa1f91e05ee4 100644 (file)
--- a/icmp.go
+++ b/icmp.go
@@ -16,7 +16,6 @@ package main
 import (
        "bytes"
        "net"
-       "net/http"
        "os"
        "sync"
        "time"
@@ -41,71 +40,29 @@ func getICMPSequence() uint16 {
        return icmpSequence
 }
 
-func probeICMP(target string, w http.ResponseWriter, module Module, registry *prometheus.Registry) (success bool) {
+func probeICMP(target string, module Module, registry *prometheus.Registry) (success bool) {
        var (
-               socket               *icmp.PacketConn
-               requestType          icmp.Type
-               replyType            icmp.Type
-               fallbackProtocol     string
-               probeIPProtocolGauge = prometheus.NewGauge(prometheus.GaugeOpts{
-                       Name: "probe_ip_protocol",
-                       Help: "Specifies whether probe ip protocl is IP4 or IP6",
-               })
-               probeDNSLookupTimeSeconds = prometheus.NewGauge(prometheus.GaugeOpts{
-                       Name: "probe_dns_lookup_time_seconds",
-                       Help: "Returns the time taken for probe dns lookup in seconds",
-               })
+               socket      *icmp.PacketConn
+               requestType icmp.Type
+               replyType   icmp.Type
        )
 
-       registry.MustRegister(probeIPProtocolGauge)
-       registry.MustRegister(probeDNSLookupTimeSeconds)
-
        deadline := time.Now().Add(module.Timeout)
 
-       // Defaults to IPv4 to be compatible with older versions
-       if module.ICMP.Protocol == "" {
-               module.ICMP.Protocol = "icmp"
-       }
-
-       // In case of ICMP prefer IPv6 by default
-       if module.ICMP.Protocol == "icmp" && module.ICMP.PreferredIPProtocol == "" {
-               module.ICMP.PreferredIPProtocol = "ip6"
-       }
-
-       if module.ICMP.Protocol == "icmp4" {
-               module.ICMP.PreferredIPProtocol = "ip4"
-               fallbackProtocol = ""
-       } else if module.ICMP.Protocol == "icmp6" {
-               module.ICMP.PreferredIPProtocol = "ip6"
-               fallbackProtocol = ""
-       } else if module.ICMP.PreferredIPProtocol == "ip6" {
-               fallbackProtocol = "ip4"
-       } else {
-               fallbackProtocol = "ip6"
-       }
-
-       resolveStart := time.Now()
-       ip, err := net.ResolveIPAddr(module.ICMP.PreferredIPProtocol, target)
-       if err != nil && fallbackProtocol != "" {
-               ip, err = net.ResolveIPAddr(fallbackProtocol, target)
-       }
-       probeDNSLookupTimeSeconds.Add(time.Since(resolveStart).Seconds())
-
+       ip, err := chooseProtocol(module.ICMP.PreferredIPProtocol, target, registry)
        if err != nil {
                log.Warnf("Error resolving address %s: %s", target, err)
-               return
+               return false
        }
 
        if ip.IP.To4() == nil {
                requestType = ipv6.ICMPTypeEchoRequest
                replyType = ipv6.ICMPTypeEchoReply
                socket, err = icmp.ListenPacket("ip6:ipv6-icmp", "::")
-               probeIPProtocolGauge.Set(6)
        } else {
                requestType = ipv4.ICMPTypeEcho
                replyType = ipv4.ICMPTypeEchoReply
                socket, err = icmp.ListenPacket("ip4:icmp", "0.0.0.0")
-               probeIPProtocolGauge.Set(4)
        }
 
        if err != nil {
@@ -129,7 +86,7 @@ func probeICMP(target string, w http.ResponseWriter, module Module, registry *pr
                log.Errorf("Error marshalling packet for %s: %s", target, err)
                return
        }
-       if _, err := socket.WriteTo(wb, ip); err != nil {
+       if _, err = socket.WriteTo(wb, ip); err != nil {
                log.Warnf("Error writing to socket for %s: %s", target, err)
                return
        }
diff --git a/main.go b/main.go
index c67e2804e2ec72f4b6ed1972c06d89329b472e73..09c004b8ad107b2cce1cb9791a629f187cc5da1c 100644 (file)
--- a/main.go
+++ b/main.go
@@ -55,6 +55,7 @@ type Module struct {
 type HTTPProbe struct {
        // Defaults to 2xx.
        ValidStatusCodes       []int             `yaml:"valid_status_codes"`
+       PreferredIPProtocol    string            `yaml:"preferred_ip_protocol"`
        NoFollowRedirects      bool              `yaml:"no_follow_redirects"`
        FailIfSSL              bool              `yaml:"fail_if_ssl"`
        FailIfNotSSL           bool              `yaml:"fail_if_not_ssl"`
@@ -63,8 +64,6 @@ type HTTPProbe struct {
        FailIfMatchesRegexp    []string          `yaml:"fail_if_matches_regexp"`
        FailIfNotMatchesRegexp []string          `yaml:"fail_if_not_matches_regexp"`
        TLSConfig              config.TLSConfig  `yaml:"tls_config"`
-       Protocol               string            `yaml:"protocol"`              // Defaults to "tcp".
-       PreferredIPProtocol    string            `yaml:"preferred_ip_protocol"` // Defaults to "ip6".
        Body                   string            `yaml:"body"`
 }
 
@@ -74,27 +73,25 @@ type QueryResponse struct {
 }
 
 type TCPProbe struct {
+       PreferredIPProtocol string           `yaml:"preferred_ip_protocol"`
        QueryResponse       []QueryResponse  `yaml:"query_response"`
        TLS                 bool             `yaml:"tls"`
        TLSConfig           config.TLSConfig `yaml:"tls_config"`
-       Protocol            string           `yaml:"protocol"`              // Defaults to "tcp".
-       PreferredIPProtocol string           `yaml:"preferred_ip_protocol"` // Defaults to "ip6".
 }
 
 type ICMPProbe struct {
-       Protocol            string `yaml:"protocol"`              // Defaults to "icmp4".
        PreferredIPProtocol string `yaml:"preferred_ip_protocol"` // Defaults to "ip6".
 }
 
 type DNSProbe struct {
-       Protocol            string         `yaml:"protocol"` // Defaults to "udp".
+       PreferredIPProtocol string         `yaml:"preferred_ip_protocol"`
+       TransportProtocol   string         `yaml:"transport_protocol"`
        QueryName           string         `yaml:"query_name"`
        QueryType           string         `yaml:"query_type"`   // Defaults to ANY.
        ValidRcodes         []string       `yaml:"valid_rcodes"` // Defaults to NOERROR.
        ValidateAnswer      DNSRRValidator `yaml:"validate_answer_rrs"`
        ValidateAuthority   DNSRRValidator `yaml:"validate_authority_rrs"`
        ValidateAdditional  DNSRRValidator `yaml:"validate_additional_rrs"`
-       PreferredIPProtocol string         `yaml:"preferred_ip_protocol"` // Defaults to "ip6".
 }
 
 type DNSRRValidator struct {
@@ -102,7 +99,7 @@ type DNSRRValidator struct {
        FailIfNotMatchesRegexp []string `yaml:"fail_if_not_matches_regexp"`
 }
 
-var Probers = map[string]func(string, http.ResponseWriter, Module, *prometheus.Registry) bool{
+var Probers = map[string]func(string, Module, *prometheus.Registry) bool{
        "http": probeHTTP,
        "tcp":  probeTCP,
        "icmp": probeICMP,
@@ -166,7 +163,7 @@ func probeHandler(w http.ResponseWriter, r *http.Request, conf *Config) {
        registry := prometheus.NewRegistry()
        registry.MustRegister(probeSuccessGauge)
        registry.MustRegister(probeDurationGauge)
-       success := prober(target, w, module, registry)
+       success := prober(target, module, registry)
        probeDurationGauge.Set(time.Since(start).Seconds())
        if success {
                probeSuccessGauge.Set(1)
diff --git a/tcp.go b/tcp.go
index 135e8eb68d62226632089a35e84c591d399a5eb1..ba6cde3f8ec723ee4bc41c3764f674ba01d2d9b7 100644 (file)
--- a/tcp.go
+++ b/tcp.go
@@ -18,7 +18,6 @@ import (
        "crypto/tls"
        "fmt"
        "net"
-       "net/http"
        "regexp"
        "time"
 
@@ -26,69 +25,45 @@ import (
        "github.com/prometheus/common/log"
 )
 
-func dialTCP(target string, w http.ResponseWriter, module Module, protocolProbeGauge prometheus.Gauge) (net.Conn, error) {
-       var dialProtocol, fallbackProtocol string
-
+func dialTCP(target string, module Module, registry *prometheus.Registry) (net.Conn, error) {
+       var dialProtocol, dialTarget string
        dialer := &net.Dialer{Timeout: module.Timeout}
-       if module.TCP.Protocol == "" {
-               module.TCP.Protocol = "tcp"
-       }
-       if module.TCP.Protocol == "tcp" && module.TCP.PreferredIPProtocol == "" {
-               module.TCP.PreferredIPProtocol = "ip6"
-       }
-       if module.TCP.PreferredIPProtocol == "ip6" {
-               fallbackProtocol = "ip4"
-       } else {
-               fallbackProtocol = "ip6"
-       }
 
-       dialProtocol = module.TCP.Protocol
-       if module.TCP.Protocol == "tcp" {
-               targetAddress, _, err := net.SplitHostPort(target)
-               ip, err := net.ResolveIPAddr(module.TCP.PreferredIPProtocol, targetAddress)
-               if err != nil {
-                       ip, err = net.ResolveIPAddr(fallbackProtocol, targetAddress)
-                       if err != nil {
-                               return nil, err
-                       }
-               }
+       targetAddress, port, err := net.SplitHostPort(target)
+       if err != nil {
+               return nil, err
+       }
 
-               if ip.IP.To4() == nil {
-                       dialProtocol = "tcp6"
-               } else {
-                       dialProtocol = "tcp4"
-               }
+       ip, err := chooseProtocol(module.TCP.PreferredIPProtocol, targetAddress, registry)
+       if err != nil {
+               return nil, err
        }
 
-       if dialProtocol == "tcp6" {
-               protocolProbeGauge.Set(6)
+       if ip.IP.To4() == nil {
+               dialProtocol = "tcp6"
        } else {
-               protocolProbeGauge.Set(4)
+               dialProtocol = "tcp4"
        }
+       dialTarget = net.JoinHostPort(ip.String(), port)
 
        if !module.TCP.TLS {
-               return dialer.Dial(dialProtocol, target)
+               return dialer.Dial(dialProtocol, dialTarget)
        }
        config, err := module.TCP.TLSConfig.GenerateConfig()
        if err != nil {
                return nil, err
        }
-       return tls.DialWithDialer(dialer, dialProtocol, target, config)
+       return tls.DialWithDialer(dialer, dialProtocol, dialTarget, config)
 }
 
-func probeTCP(target string, w http.ResponseWriter, module Module, registry *prometheus.Registry) bool {
-       probeIPProtocolGauge := prometheus.NewGauge(prometheus.GaugeOpts{
-               Name: "probe_ip_protocol",
-               Help: "Specifies whether probe ip protocol is IP4 or IP6",
-       })
+func probeTCP(target string, module Module, registry *prometheus.Registry) bool {
        probeSSLEarliestCertExpiry := prometheus.NewGauge(prometheus.GaugeOpts{
                Name: "probe_ssl_earliest_cert_expiry",
                Help: "Returns earliest SSL cert expiry date",
        })
-       registry.MustRegister(probeIPProtocolGauge)
        registry.MustRegister(probeSSLEarliestCertExpiry)
        deadline := time.Now().Add(module.Timeout)
-       conn, err := dialTCP(target, w, module, probeIPProtocolGauge)
+       conn, err := dialTCP(target, module, registry)
        if err != nil {
                return false
        }
index 2968e9634850df539fb9ce7aca6b1c9ff59f19ea..7add94c7f4f058008bd1cbd6ef9e936b44c3dae7 100644 (file)
@@ -16,7 +16,6 @@ package main
 import (
        "fmt"
        "net"
-       "net/http/httptest"
        "runtime"
        "testing"
        "time"
@@ -40,9 +39,8 @@ func TestTCPConnection(t *testing.T) {
                conn.Close()
                ch <- struct{}{}
        }()
-       recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       if !probeTCP(ln.Addr().String(), recorder, Module{Timeout: time.Second}, registry) {
+       if !probeTCP(ln.Addr().String(), Module{Timeout: time.Second}, registry) {
                t.Fatalf("TCP module failed, expected success.")
        }
        <-ch
@@ -50,9 +48,8 @@ func TestTCPConnection(t *testing.T) {
 
 func TestTCPConnectionFails(t *testing.T) {
        // Invalid port number.
-       recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       if probeTCP(":0", recorder, Module{Timeout: time.Second}, registry) {
+       if probeTCP(":0", Module{Timeout: time.Second}, registry) {
                t.Fatalf("TCP module suceeded, expected failure.")
        }
 }
@@ -89,9 +86,8 @@ func TestTCPConnectionQueryResponseIRC(t *testing.T) {
                conn.Close()
                ch <- struct{}{}
        }()
-       recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       if !probeTCP(ln.Addr().String(), recorder, module, registry) {
+       if !probeTCP(ln.Addr().String(), module, registry) {
                t.Fatalf("TCP module failed, expected success.")
        }
        <-ch
@@ -110,7 +106,7 @@ func TestTCPConnectionQueryResponseIRC(t *testing.T) {
                ch <- struct{}{}
        }()
        registry = prometheus.NewRegistry()
-       if probeTCP(ln.Addr().String(), recorder, module, registry) {
+       if probeTCP(ln.Addr().String(), module, registry) {
                t.Fatalf("TCP module succeeded, expected failure.")
        }
        <-ch
@@ -148,9 +144,8 @@ func TestTCPConnectionQueryResponseMatching(t *testing.T) {
                conn.Close()
                ch <- version
        }()
-       recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       if !probeTCP(ln.Addr().String(), recorder, module, registry) {
+       if !probeTCP(ln.Addr().String(), module, registry) {
                t.Fatalf("TCP module failed, expected success.")
        }
        if got, want := <-ch, "OpenSSH_6.9p1"; got != want {
@@ -182,13 +177,12 @@ func TestTCPConnectionProtocol(t *testing.T) {
        module := Module{
                Timeout: time.Second,
                TCP: TCPProbe{
-                       Protocol: "tcp4",
+                       PreferredIPProtocol: "ip4",
                },
        }
 
-       recorder := httptest.NewRecorder()
        registry := prometheus.NewRegistry()
-       result := probeTCP(net.JoinHostPort("localhost", port), recorder, module, registry)
+       result := probeTCP(net.JoinHostPort("localhost", port), module, registry)
        if !result {
                t.Fatalf("TCP protocol: \"tcp4\" connection test failed, expected success.")
        }
@@ -204,14 +198,11 @@ func TestTCPConnectionProtocol(t *testing.T) {
        // Force IPv6
        module = Module{
                Timeout: time.Second,
-               TCP: TCPProbe{
-                       Protocol: "tcp6",
-               },
+               TCP:     TCPProbe{},
        }
 
-       recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module, registry)
+       result = probeTCP(net.JoinHostPort("localhost", port), module, registry)
        if !result {
                t.Fatalf("TCP protocol: \"tcp6\" connection test failed, expected success.")
        }
@@ -228,14 +219,12 @@ func TestTCPConnectionProtocol(t *testing.T) {
        module = Module{
                Timeout: time.Second,
                TCP: TCPProbe{
-                       Protocol:            "tcp",
                        PreferredIPProtocol: "ip4",
                },
        }
 
-       recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module, registry)
+       result = probeTCP(net.JoinHostPort("localhost", port), module, registry)
        if !result {
                t.Fatalf("TCP protocol: \"tcp\", prefer: \"ip4\" connection test failed, expected success.")
        }
@@ -252,14 +241,12 @@ func TestTCPConnectionProtocol(t *testing.T) {
        module = Module{
                Timeout: time.Second,
                TCP: TCPProbe{
-                       Protocol:            "tcp",
                        PreferredIPProtocol: "ip6",
                },
        }
 
-       recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module, registry)
+       result = probeTCP(net.JoinHostPort("localhost", port), module, registry)
        if !result {
                t.Fatalf("TCP protocol: \"tcp\", prefer: \"ip6\" connection test failed, expected success.")
        }
@@ -275,14 +262,11 @@ func TestTCPConnectionProtocol(t *testing.T) {
        // Prefer nothing
        module = Module{
                Timeout: time.Second,
-               TCP: TCPProbe{
-                       Protocol: "tcp",
-               },
+               TCP:     TCPProbe{},
        }
 
-       recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module, registry)
+       result = probeTCP(net.JoinHostPort("localhost", port), module, registry)
        if !result {
                t.Fatalf("TCP protocol: \"tcp\" connection test failed, expected success.")
        }
@@ -301,9 +285,8 @@ func TestTCPConnectionProtocol(t *testing.T) {
                TCP:     TCPProbe{},
        }
 
-       recorder = httptest.NewRecorder()
        registry = prometheus.NewRegistry()
-       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module, registry)
+       result = probeTCP(net.JoinHostPort("localhost", port), module, registry)
        if !result {
                t.Fatalf("TCP connection test with protocol unspecified failed, expected success.")
        }
diff --git a/utils.go b/utils.go
new file mode 100644 (file)
index 0000000..c885c7f
--- /dev/null
+++ b/utils.go
@@ -0,0 +1,60 @@
+package main
+
+import (
+       "net"
+       "time"
+
+       "github.com/prometheus/client_golang/prometheus"
+)
+
+// Returns the preferedIPProtocol, the dialProtocol, and sets the probeIPProtocolGauge.
+func chooseProtocol(preferredIPProtocol string, target string, registry *prometheus.Registry) (*net.IPAddr, error) {
+       var fallbackProtocol string
+
+       probeDNSLookupTimeSeconds := prometheus.NewGauge(prometheus.GaugeOpts{
+               Name: "probe_dns_lookup_time_seconds",
+               Help: "Returns the time taken for probe dns lookup in seconds",
+       })
+
+       probeIPProtocolGauge := prometheus.NewGauge(prometheus.GaugeOpts{
+               Name: "probe_ip_protocol",
+               Help: "Specifies whether probe ip protocol is IP4 or IP6",
+       })
+       registry.MustRegister(probeIPProtocolGauge)
+       registry.MustRegister(probeDNSLookupTimeSeconds)
+
+       if preferredIPProtocol == "ip6" || preferredIPProtocol == "" {
+               preferredIPProtocol = "ip6"
+               fallbackProtocol = "ip4"
+       } else {
+               preferredIPProtocol = "ip4"
+               fallbackProtocol = "ip6"
+       }
+
+       if preferredIPProtocol == "ip6" {
+               fallbackProtocol = "ip4"
+       } else {
+               fallbackProtocol = "ip6"
+       }
+
+       resolveStart := time.Now()
+
+       ip, err := net.ResolveIPAddr(preferredIPProtocol, target)
+       if err != nil {
+               ip, err = net.ResolveIPAddr(fallbackProtocol, target)
+               if err != nil {
+                       return ip, err
+               }
+       }
+
+       probeDNSLookupTimeSeconds.Add(time.Since(resolveStart).Seconds())
+
+       if ip.IP.To4() == nil {
+               probeIPProtocolGauge.Set(6)
+       } else {
+               probeIPProtocolGauge.Set(4)
+       }
+
+       return ip, nil
+
+}