IPv6 support for all modules (#64)
authorHasso Tepper <hasso.tepper@gmail.com>
Fri, 30 Sep 2016 16:29:52 +0000 (19:29 +0300)
committerBrian Brazil <brian-brazil@users.noreply.github.com>
Fri, 30 Sep 2016 16:29:52 +0000 (17:29 +0100)
* IPv6 and IP protocol preferrence support for all modules

Introduce protocol field for every module which can be used to force
probe to IPv4 (tcp4/udp4/icmp4) or IPv6 (tcp6/udp6/icmp6). In case
of tcp/udp/icmp both can be used in order of preferrence. IPv6 is
preferred by default, but it can be tuned with 'preferred_ip_protocol'
parameter.

Default for all modules is an automatic selection with IPv6 preferred.

ICMP code is mostly from Michael Stapelberg.

* Document protocol and preferred_ip_protocol parameters

* Protocol tests for DNS and TCP modules

README.md
dns.go
dns_test.go
http.go
icmp.go
main.go
tcp.go
tcp_test.go

index 84780715b92f1606850812347754c7da747a2e50..7ab99b60cb30459bb8302f0350f8f55deaf7704e 100644 (file)
--- a/README.md
+++ b/README.md
@@ -45,9 +45,14 @@ modules:
       - "Download the latest version here"
       tls_config:
         insecure_skip_verify: false
+      protocol: "tcp" # accepts "tcp/tcp4/tcp6", defaults to "tcp"
+      preferred_ip_protocol: "ip4" # used for "tcp", defaults to "ip6"
   tcp_connect:
     prober: tcp
     timeout: 5s
+    tcp:
+      protocol: "tcp"
+      preferred_ip_protocol: "ip4"
   pop3s_banner:
     prober: tcp
     tcp:
@@ -75,6 +80,9 @@ modules:
   icmp:
     prober: icmp
     timeout: 5s
+    icmp:
+      protocol: "icmp"
+      preferred_ip_protocol: "ip4"
   dns_udp:
     prober: dns
     timeout: 5s
@@ -97,7 +105,8 @@ modules:
   dns_tcp:
     prober: dns
     dns:
-      protocol: "tcp"  # can also be something like "udp4" or "tcp6"
+      protocol: "tcp" # accepts "tcp/tcp4/tcp6/udp/udp4/udp6", defaults to "udp"
+      preferred_ip_protocol: "ip4" # used for "udp/tcp", defaults to "ip6"
       query_name: "www.prometheus.io"
 ```
 
diff --git a/dns.go b/dns.go
index bf3e25bc4553a7aa7600bb992f556c5b2bb34a18..8559064dd05b5d5842fa7e34504383c8365bcda9 100644 (file)
--- a/dns.go
+++ b/dns.go
@@ -15,6 +15,7 @@ package main
 
 import (
        "fmt"
+       "net"
        "net/http"
        "regexp"
 
@@ -82,6 +83,7 @@ func validRcode(rcode int, valid []string) bool {
 
 func probeDNS(target string, w http.ResponseWriter, module Module) bool {
        var numAnswer, numAuthority, numAdditional int
+       var dialProtocol, fallbackProtocol string
        defer func() {
                // These metrics can be used to build additional alerting based on the number of replies.
                // They should be returned even in case of errors.
@@ -90,8 +92,45 @@ func probeDNS(target string, w http.ResponseWriter, module Module) bool {
                fmt.Fprintf(w, "probe_dns_additional_rrs %d\n", numAdditional)
        }()
 
+       if module.DNS.Protocol == "" {
+               module.DNS.Protocol = "udp"
+       }
+
+       if (module.DNS.Protocol == "tcp" || module.DNS.Protocol == "udp") && module.DNS.PreferredIpProtocol == "" {
+               module.DNS.PreferredIpProtocol = "ip6"
+       }
+       if module.DNS.PreferredIpProtocol == "ip6" {
+               fallbackProtocol = "ip4"
+       } else {
+               fallbackProtocol = "ip6"
+       }
+
+       dialProtocol = module.DNS.Protocol
+       if module.DNS.Protocol == "udp" || module.DNS.Protocol == "tcp" {
+               target_address, _, _ := net.SplitHostPort(target)
+               ip, err := net.ResolveIPAddr(module.DNS.PreferredIpProtocol, target_address)
+               if err != nil {
+                       ip, err = net.ResolveIPAddr(fallbackProtocol, target_address)
+                       if err != nil {
+                               return false
+                       }
+               }
+
+               if ip.IP.To4() == nil {
+                       dialProtocol = module.DNS.Protocol + "6"
+               } else {
+                       dialProtocol = module.DNS.Protocol + "4"
+               }
+       }
+
+       if dialProtocol[len(dialProtocol)-1] == '6' {
+               fmt.Fprintf(w, "probe_ip_protocol 6\n")
+       } else {
+               fmt.Fprintf(w, "probe_ip_protocol 4\n")
+       }
+
        client := new(dns.Client)
-       client.Net = module.DNS.Protocol
+       client.Net = dialProtocol
        client.Timeout = module.Timeout
 
        qt := dns.TypeANY
index 3f9002dc47f37752c891423215723686206e7f2d..8982f3eadbcc3ebcfe197777e4a86786cdfcfb85 100644 (file)
@@ -16,6 +16,7 @@ package main
 import (
        "net"
        "net/http/httptest"
+       "runtime"
        "strings"
        "testing"
        "time"
@@ -318,3 +319,138 @@ func TestServfailDNSResponse(t *testing.T) {
                }
        }
 }
+
+func TestDNSProtocol(t *testing.T) {
+       // This test assumes that listening "tcp" listens both IPv6 and IPv4 traffic and
+       // localhost resolves to both 127.0.0.1 and ::1. we must skip the test if either
+       // of these isn't true. This should be true for modern Linux systems.
+       if runtime.GOOS == "dragonfly" || runtime.GOOS == "openbsd" {
+               t.Skip("IPv6 socket isn't able to accept IPv4 traffic in the system.")
+       }
+       _, err := net.ResolveIPAddr("ip6", "localhost")
+       if err != nil {
+               t.Skip("\"localhost\" doesn't resolve to ::1.")
+       }
+
+       for _, protocol := range PROTOCOLS {
+               server, addr := startDNSServer(protocol, recursiveDNSHandler)
+               defer server.Shutdown()
+
+               _, port, _ := net.SplitHostPort(addr.String())
+
+               // Force IPv4
+               module := Module{
+                       Timeout: time.Second,
+                       DNS: DNSProbe{
+                               QueryName: "example.com",
+                               Protocol:  protocol + "4",
+                       },
+               }
+               recorder := httptest.NewRecorder()
+               result := probeDNS(net.JoinHostPort("localhost", port), recorder, module)
+               body := recorder.Body.String()
+               if !result {
+                       t.Fatalf("DNS protocol: \"%v4\" connection test failed, expected success.", protocol)
+               }
+               if !strings.Contains(body, "probe_ip_protocol 4\n") {
+                       t.Fatalf("Expected IPv4, got %s", body)
+               }
+
+               // Force IPv6
+               module = Module{
+                       Timeout: time.Second,
+                       DNS: DNSProbe{
+                               QueryName: "example.com",
+                               Protocol:  protocol + "6",
+                       },
+               }
+               recorder = httptest.NewRecorder()
+               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module)
+               body = recorder.Body.String()
+               if !result {
+                       t.Fatalf("DNS protocol: \"%v6\" connection test failed, expected success.", protocol)
+               }
+               if !strings.Contains(body, "probe_ip_protocol 6\n") {
+                       t.Fatalf("Expected IPv6, got %s", body)
+               }
+
+               // Prefer IPv6
+               module = Module{
+                       Timeout: time.Second,
+                       DNS: DNSProbe{
+                               QueryName:           "example.com",
+                               Protocol:            protocol,
+                               PreferredIpProtocol: "ip6",
+                       },
+               }
+               recorder = httptest.NewRecorder()
+               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module)
+               body = recorder.Body.String()
+               if !result {
+                       t.Fatalf("DNS protocol: \"%v\", preferred \"ip6\" connection test failed, expected success.", protocol)
+               }
+               if !strings.Contains(body, "probe_ip_protocol 6\n") {
+                       t.Fatalf("Expected IPv6, got %s", body)
+               }
+
+               // Prefer IPv4
+               module = Module{
+                       Timeout: time.Second,
+                       DNS: DNSProbe{
+                               QueryName:           "example.com",
+                               Protocol:            protocol,
+                               PreferredIpProtocol: "ip4",
+                       },
+               }
+               recorder = httptest.NewRecorder()
+               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module)
+               body = recorder.Body.String()
+               if !result {
+                       t.Fatalf("DNS protocol: \"%v\", preferred \"ip4\" connection test failed, expected success.", protocol)
+               }
+               if !strings.Contains(body, "probe_ip_protocol 4\n") {
+                       t.Fatalf("Expected IPv4, got %s", body)
+               }
+
+               // Prefer none
+               module = Module{
+                       Timeout: time.Second,
+                       DNS: DNSProbe{
+                               QueryName: "example.com",
+                               Protocol:  protocol,
+                       },
+               }
+               recorder = httptest.NewRecorder()
+               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module)
+               body = recorder.Body.String()
+               if !result {
+                       t.Fatalf("DNS protocol: \"%v\" connection test failed, expected success.", protocol)
+               }
+               if !strings.Contains(body, "probe_ip_protocol 6\n") {
+                       t.Fatalf("Expected IPv6, got %s", body)
+               }
+
+               // No protocol
+               module = Module{
+                       Timeout: time.Second,
+                       DNS: DNSProbe{
+                               QueryName: "example.com",
+                       },
+               }
+               recorder = httptest.NewRecorder()
+               result = probeDNS(net.JoinHostPort("localhost", port), recorder, module)
+               body = recorder.Body.String()
+               if protocol == "udp" {
+                       if !result {
+                               t.Fatalf("DNS test connection with protocol unspecified failed, expected success.", protocol)
+                       }
+               } else {
+                       if result {
+                               t.Fatalf("DNS test connection with protocol unspecified succeeded, expected failure.", protocol)
+                       }
+               }
+               if !strings.Contains(body, "probe_ip_protocol 6\n") {
+                       t.Fatalf("Expected IPv6, got %s", body)
+               }
+       }
+}
diff --git a/http.go b/http.go
index df8ddbadc88b98c68976af2935080db035fc0a99..61550f527c3764230b8f3792b6630a61e39e427a 100644 (file)
--- a/http.go
+++ b/http.go
@@ -18,7 +18,9 @@ import (
        "fmt"
        "io"
        "io/ioutil"
+       "net"
        "net/http"
+       "net/url"
        "regexp"
        "strings"
 
@@ -56,8 +58,55 @@ func matchRegularExpressions(reader io.Reader, config HTTPProbe) bool {
 
 func probeHTTP(target string, w http.ResponseWriter, module Module) (success bool) {
        var isSSL, redirects int
+       var dialProtocol, fallbackProtocol string
+
        config := module.HTTP
 
+       if module.HTTP.Protocol == "" {
+               module.HTTP.Protocol = "tcp"
+       }
+
+       if module.HTTP.Protocol == "tcp" && module.HTTP.PreferredIpProtocol == "" {
+               module.HTTP.PreferredIpProtocol = "ip6"
+       }
+       if module.HTTP.PreferredIpProtocol == "ip6" {
+               fallbackProtocol = "ip4"
+       } else {
+               fallbackProtocol = "ip6"
+       }
+
+       dialProtocol = module.HTTP.Protocol
+       if module.HTTP.Protocol == "tcp" {
+               target_url, err := url.Parse(target)
+               if err != nil {
+                       return false
+               }
+               target_host, _, err := net.SplitHostPort(target_url.Host)
+               // If split fails, assuming it's a hostname without port part
+               if err != nil {
+                       target_host = target_url.Host
+               }
+               ip, err := net.ResolveIPAddr(module.HTTP.PreferredIpProtocol, target_host)
+               if err != nil {
+                       ip, err = net.ResolveIPAddr(fallbackProtocol, target_host)
+                       if err != nil {
+                               return false
+                       }
+               }
+
+               if ip.IP.To4() == nil {
+                       dialProtocol = "tcp6"
+               } else {
+                       dialProtocol = "tcp4"
+               }
+       }
+
+       if dialProtocol == "tcp6" {
+               fmt.Fprintf(w, "probe_ip_protocol 6\n")
+       } else {
+               fmt.Fprintf(w, "probe_ip_protocol 4\n")
+       }
+
        client := &http.Client{
                Timeout: module.Timeout,
        }
@@ -67,8 +116,12 @@ func probeHTTP(target string, w http.ResponseWriter, module Module) (success boo
                log.Errorf("Error generating TLS config: %s", err)
                return false
        }
+       dial := func(network, address string) (net.Conn, error) {
+               return net.Dial(dialProtocol, address)
+       }
        client.Transport = &http.Transport{
                TLSClientConfig: tlsconfig,
+               Dial:            dial,
        }
 
        client.CheckRedirect = func(_ *http.Request, via []*http.Request) error {
diff --git a/icmp.go b/icmp.go
index 3253348999f9464d17658ceda20ddfc1ffc8d57b..cb8cf599644c02c4b2daf0d0c95efe0cf5e87f87 100644 (file)
--- a/icmp.go
+++ b/icmp.go
@@ -15,8 +15,10 @@ package main
 
 import (
        "bytes"
+       "fmt"
        "golang.org/x/net/icmp"
        "golang.org/x/net/ipv4"
+       "golang.org/x/net/ipv6"
        "net"
        "net/http"
        "os"
@@ -39,15 +41,63 @@ func getICMPSequence() uint16 {
 }
 
 func probeICMP(target string, w http.ResponseWriter, module Module) (success bool) {
+       var (
+               socket           *icmp.PacketConn
+               requestType      icmp.Type
+               replyType        icmp.Type
+               fallbackProtocol string
+       )
+
        deadline := time.Now().Add(module.Timeout)
-       socket, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
+
+       // Defaults to IPv4 to be compatible with older versions
+       if module.ICMP.Protocol == "" {
+               module.ICMP.Protocol = "icmp"
+       }
+
+       // In case of ICMP prefer IPv6 by default
+       if module.ICMP.Protocol == "icmp" && module.ICMP.PreferredIpProtocol == "" {
+               module.ICMP.PreferredIpProtocol = "ip6"
+       }
+
+       if module.ICMP.Protocol == "icmp4" {
+               module.ICMP.PreferredIpProtocol = "ip4"
+               fallbackProtocol = ""
+       } else if module.ICMP.Protocol == "icmp6" {
+               module.ICMP.PreferredIpProtocol = "ip6"
+               fallbackProtocol = ""
+       } else if module.ICMP.PreferredIpProtocol == "ip6" {
+               fallbackProtocol = "ip4"
+       } else {
+               fallbackProtocol = "ip6"
+       }
+
+       ip, err := net.ResolveIPAddr(module.ICMP.PreferredIpProtocol, target)
+       if err != nil && fallbackProtocol != "" {
+               ip, err = net.ResolveIPAddr(fallbackProtocol, target)
+       }
+       if err != nil {
+               log.Errorf("Error resolving address %s: %s", target, err)
+       }
+
+       if ip.IP.To4() == nil {
+               requestType = ipv6.ICMPTypeEchoRequest
+               replyType = ipv6.ICMPTypeEchoReply
+               socket, err = icmp.ListenPacket("ip6:ipv6-icmp", "::")
+               fmt.Fprintf(w, "probe_ip_protocol 6\n")
+       } else {
+               requestType = ipv4.ICMPTypeEcho
+               replyType = ipv4.ICMPTypeEchoReply
+               socket, err = icmp.ListenPacket("ip4:icmp", "0.0.0.0")
+               fmt.Fprintf(w, "probe_ip_protocol 4\n")
+       }
+
        if err != nil {
                log.Errorf("Error listening to socket: %s", err)
                return
        }
        defer socket.Close()
 
-       ip, err := net.ResolveIPAddr("ip4", target)
        if err != nil {
                log.Errorf("Error resolving address %s: %s", target, err)
                return
@@ -57,12 +107,14 @@ func probeICMP(target string, w http.ResponseWriter, module Module) (success boo
        pid := os.Getpid() & 0xffff
 
        wm := icmp.Message{
-               Type: ipv4.ICMPTypeEcho, Code: 0,
+               Type: requestType,
+               Code: 0,
                Body: &icmp.Echo{
                        ID: pid, Seq: int(seq),
                        Data: []byte("Prometheus Blackbox Exporter"),
                },
        }
+
        wb, err := wm.Marshal(nil)
        if err != nil {
                log.Errorf("Error marshalling packet for %s: %s", target, err)
@@ -74,7 +126,7 @@ func probeICMP(target string, w http.ResponseWriter, module Module) (success boo
        }
 
        // Reply should be the same except for the message type.
-       wm.Type = ipv4.ICMPTypeEchoReply
+       wm.Type = replyType
        wb, err = wm.Marshal(nil)
        if err != nil {
                log.Errorf("Error marshalling packet for %s: %s", target, err)
@@ -99,6 +151,11 @@ func probeICMP(target string, w http.ResponseWriter, module Module) (success boo
                if peer.String() != ip.String() {
                        continue
                }
+               if replyType == ipv6.ICMPTypeEchoReply {
+                       // Clear checksum to make comparison succeed.
+                       rb[2] = 0
+                       rb[3] = 0
+               }
                if bytes.Compare(rb[:n], wb) == 0 {
                        success = true
                        return
diff --git a/main.go b/main.go
index db2844cd87e0497ea6de9bd9c56e892d2bfb4551..e5a0523666dab5690aff296dc2b5e4fceb9e1912 100644 (file)
--- a/main.go
+++ b/main.go
@@ -59,6 +59,8 @@ type HTTPProbe struct {
        FailIfMatchesRegexp    []string          `yaml:"fail_if_matches_regexp"`
        FailIfNotMatchesRegexp []string          `yaml:"fail_if_not_matches_regexp"`
        TLSConfig              config.TLSConfig  `yaml:"tls_config"`
+       Protocol               string            `yaml:"protocol"`              // Defaults to "tcp".
+       PreferredIpProtocol    string            `yaml:"preferred_ip_protocol"` // Defaults to "ip6".
 }
 
 type QueryResponse struct {
@@ -67,22 +69,27 @@ type QueryResponse struct {
 }
 
 type TCPProbe struct {
-       QueryResponse []QueryResponse  `yaml:"query_response"`
-       TLS           bool             `yaml:"tls"`
-       TLSConfig     config.TLSConfig `yaml:"tls_config"`
+       QueryResponse       []QueryResponse  `yaml:"query_response"`
+       TLS                 bool             `yaml:"tls"`
+       TLSConfig           config.TLSConfig `yaml:"tls_config"`
+       Protocol            string           `yaml:"protocol"`              // Defaults to "tcp".
+       PreferredIpProtocol string           `yaml:"preferred_ip_protocol"` // Defaults to "ip6".
 }
 
 type ICMPProbe struct {
+       Protocol            string `yaml:"protocol"`              // Defaults to "icmp4".
+       PreferredIpProtocol string `yaml:"preferred_ip_protocol"` // Defaults to "ip6".
 }
 
 type DNSProbe struct {
-       Protocol           string         `yaml:"protocol"` // Defaults to "udp".
-       QueryName          string         `yaml:"query_name"`
-       QueryType          string         `yaml:"query_type"`   // Defaults to ANY.
-       ValidRcodes        []string       `yaml:"valid_rcodes"` // Defaults to NOERROR.
-       ValidateAnswer     DNSRRValidator `yaml:"validate_answer_rrs"`
-       ValidateAuthority  DNSRRValidator `yaml:"validate_authority_rrs"`
-       ValidateAdditional DNSRRValidator `yaml:"validate_additional_rrs"`
+       Protocol            string         `yaml:"protocol"` // Defaults to "udp".
+       QueryName           string         `yaml:"query_name"`
+       QueryType           string         `yaml:"query_type"`   // Defaults to ANY.
+       ValidRcodes         []string       `yaml:"valid_rcodes"` // Defaults to NOERROR.
+       ValidateAnswer      DNSRRValidator `yaml:"validate_answer_rrs"`
+       ValidateAuthority   DNSRRValidator `yaml:"validate_authority_rrs"`
+       ValidateAdditional  DNSRRValidator `yaml:"validate_additional_rrs"`
+       PreferredIpProtocol string         `yaml:"preferred_ip_protocol"` // Defaults to "ip6".
 }
 
 type DNSRRValidator struct {
diff --git a/tcp.go b/tcp.go
index bcba65dfcac835d61681cc35a54ff653577e0bbd..be1b7788c57eb737923cda841a36812f5968a271 100644 (file)
--- a/tcp.go
+++ b/tcp.go
@@ -25,21 +25,59 @@ import (
        "github.com/prometheus/common/log"
 )
 
-func dialTCP(target string, module Module) (net.Conn, error) {
+func dialTCP(target string, w http.ResponseWriter, module Module) (net.Conn, error) {
+       var dialProtocol, fallbackProtocol string
+
        dialer := &net.Dialer{Timeout: module.Timeout}
+       if module.TCP.Protocol == "" {
+               module.TCP.Protocol = "tcp"
+       }
+       if module.TCP.Protocol == "tcp" && module.TCP.PreferredIpProtocol == "" {
+               module.TCP.PreferredIpProtocol = "ip6"
+       }
+       if module.TCP.PreferredIpProtocol == "ip6" {
+               fallbackProtocol = "ip4"
+       } else {
+               fallbackProtocol = "ip6"
+       }
+
+       dialProtocol = module.TCP.Protocol
+       if module.TCP.Protocol == "tcp" {
+               target_address, _, err := net.SplitHostPort(target)
+               ip, err := net.ResolveIPAddr(module.TCP.PreferredIpProtocol, target_address)
+               if err != nil {
+                       ip, err = net.ResolveIPAddr(fallbackProtocol, target_address)
+                       if err != nil {
+                               return nil, err
+                       }
+               }
+
+               if ip.IP.To4() == nil {
+                       dialProtocol = "tcp6"
+               } else {
+                       dialProtocol = "tcp4"
+               }
+       }
+
+       if dialProtocol == "tcp6" {
+               fmt.Fprintf(w, "probe_ip_protocol 6\n")
+       } else {
+               fmt.Fprintf(w, "probe_ip_protocol 4\n")
+       }
+
        if !module.TCP.TLS {
-               return dialer.Dial("tcp", target)
+               return dialer.Dial(dialProtocol, target)
        }
        config, err := module.TCP.TLSConfig.GenerateConfig()
        if err != nil {
                return nil, err
        }
-       return tls.DialWithDialer(dialer, "tcp", target, config)
+       return tls.DialWithDialer(dialer, dialProtocol, target, config)
 }
 
 func probeTCP(target string, w http.ResponseWriter, module Module) bool {
        deadline := time.Now().Add(module.Timeout)
-       conn, err := dialTCP(target, module)
+       conn, err := dialTCP(target, w, module)
        if err != nil {
                return false
        }
index 70474e917638514d49ec79e249882d9ce36f9044..e848832517e73efd315e21a26c29e4754428ddc3 100644 (file)
@@ -16,6 +16,9 @@ package main
 import (
        "fmt"
        "net"
+       "net/http/httptest"
+       "runtime"
+       "strings"
        "testing"
        "time"
 )
@@ -36,7 +39,8 @@ func TestTCPConnection(t *testing.T) {
                conn.Close()
                ch <- struct{}{}
        }()
-       if !probeTCP(ln.Addr().String(), nil, Module{Timeout: time.Second}) {
+       recorder := httptest.NewRecorder()
+       if !probeTCP(ln.Addr().String(), recorder, Module{Timeout: time.Second}) {
                t.Fatalf("TCP module failed, expected success.")
        }
        <-ch
@@ -44,7 +48,8 @@ func TestTCPConnection(t *testing.T) {
 
 func TestTCPConnectionFails(t *testing.T) {
        // Invalid port number.
-       if probeTCP(":0", nil, Module{Timeout: time.Second}) {
+       recorder := httptest.NewRecorder()
+       if probeTCP(":0", recorder, Module{Timeout: time.Second}) {
                t.Fatalf("TCP module suceeded, expected failure.")
        }
 }
@@ -81,7 +86,8 @@ func TestTCPConnectionQueryResponseIRC(t *testing.T) {
                conn.Close()
                ch <- struct{}{}
        }()
-       if !probeTCP(ln.Addr().String(), nil, module) {
+       recorder := httptest.NewRecorder()
+       if !probeTCP(ln.Addr().String(), recorder, module) {
                t.Fatalf("TCP module failed, expected success.")
        }
        <-ch
@@ -99,7 +105,7 @@ func TestTCPConnectionQueryResponseIRC(t *testing.T) {
                conn.Close()
                ch <- struct{}{}
        }()
-       if probeTCP(ln.Addr().String(), nil, module) {
+       if probeTCP(ln.Addr().String(), recorder, module) {
                t.Fatalf("TCP module succeeded, expected failure.")
        }
        <-ch
@@ -137,10 +143,140 @@ func TestTCPConnectionQueryResponseMatching(t *testing.T) {
                conn.Close()
                ch <- version
        }()
-       if !probeTCP(ln.Addr().String(), nil, module) {
+       recorder := httptest.NewRecorder()
+       if !probeTCP(ln.Addr().String(), recorder, module) {
                t.Fatalf("TCP module failed, expected success.")
        }
        if got, want := <-ch, "OpenSSH_6.9p1"; got != want {
                t.Fatalf("Read unexpected version: got %q, want %q", got, want)
        }
 }
+
+func TestTCPConnectionProtocol(t *testing.T) {
+       // This test assumes that listening "tcp" listens both IPv6 and IPv4 traffic and
+       // localhost resolves to both 127.0.0.1 and ::1. we must skip the test if either
+       // of these isn't true. This should be true for modern Linux systems.
+       if runtime.GOOS == "dragonfly" || runtime.GOOS == "openbsd" {
+               t.Skip("IPv6 socket isn't able to accept IPv4 traffic in the system.")
+       }
+       _, err := net.ResolveIPAddr("ip6", "localhost")
+       if err != nil {
+               t.Skip("\"localhost\" doesn't resolve to ::1.")
+       }
+
+       ln, err := net.Listen("tcp", ":0")
+       if err != nil {
+               t.Fatalf("Error listening on socket: %s", err)
+       }
+       defer ln.Close()
+
+       _, port, _ := net.SplitHostPort(ln.Addr().String())
+
+       // Force IPv4
+       module := Module{
+               Timeout: time.Second,
+               TCP: TCPProbe{
+                       Protocol: "tcp4",
+               },
+       }
+
+       recorder := httptest.NewRecorder()
+       result := probeTCP(net.JoinHostPort("localhost", port), recorder, module)
+       body := recorder.Body.String()
+       if !result {
+               t.Fatalf("TCP protocol: \"tcp4\" connection test failed, expected success.")
+       }
+       if !strings.Contains(body, "probe_ip_protocol 4\n") {
+               t.Fatalf("Expected IPv4, got %s", body)
+       }
+
+       // Force IPv6
+       module = Module{
+               Timeout: time.Second,
+               TCP: TCPProbe{
+                       Protocol: "tcp6",
+               },
+       }
+
+       recorder = httptest.NewRecorder()
+       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module)
+       body = recorder.Body.String()
+       if !result {
+               t.Fatalf("TCP protocol: \"tcp6\" connection test failed, expected success.")
+       }
+       if !strings.Contains(body, "probe_ip_protocol 6\n") {
+               t.Fatalf("Expected IPv6, got %s", body)
+       }
+
+       // Prefer IPv4
+       module = Module{
+               Timeout: time.Second,
+               TCP: TCPProbe{
+                       Protocol:            "tcp",
+                       PreferredIpProtocol: "ip4",
+               },
+       }
+
+       recorder = httptest.NewRecorder()
+       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module)
+       body = recorder.Body.String()
+       if !result {
+               t.Fatalf("TCP protocol: \"tcp\", prefer: \"ip4\" connection test failed, expected success.")
+       }
+       if !strings.Contains(body, "probe_ip_protocol 4\n") {
+               t.Fatalf("Expected IPv4, got %s", body)
+       }
+
+       // Prefer IPv6
+       module = Module{
+               Timeout: time.Second,
+               TCP: TCPProbe{
+                       Protocol:            "tcp",
+                       PreferredIpProtocol: "ip6",
+               },
+       }
+
+       recorder = httptest.NewRecorder()
+       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module)
+       body = recorder.Body.String()
+       if !result {
+               t.Fatalf("TCP protocol: \"tcp\", prefer: \"ip6\" connection test failed, expected success.")
+       }
+       if !strings.Contains(body, "probe_ip_protocol 6\n") {
+               t.Fatalf("Expected IPv6, got %s", body)
+       }
+
+       // Prefer nothing
+       module = Module{
+               Timeout: time.Second,
+               TCP: TCPProbe{
+                       Protocol: "tcp",
+               },
+       }
+
+       recorder = httptest.NewRecorder()
+       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module)
+       body = recorder.Body.String()
+       if !result {
+               t.Fatalf("TCP protocol: \"tcp\" connection test failed, expected success.")
+       }
+       if !strings.Contains(body, "probe_ip_protocol 6\n") {
+               t.Fatalf("Expected IPv6, got %s", body)
+       }
+
+       // No protocol
+       module = Module{
+               Timeout: time.Second,
+               TCP:     TCPProbe{},
+       }
+
+       recorder = httptest.NewRecorder()
+       result = probeTCP(net.JoinHostPort("localhost", port), recorder, module)
+       body = recorder.Body.String()
+       if !result {
+               t.Fatalf("TCP connection test with protocol unspecified failed, expected success.")
+       }
+       if !strings.Contains(body, "probe_ip_protocol 6\n") {
+               t.Fatalf("Expected IPv6, got %s", body)
+       }
+}