[prober/tcp] get servername for TLS from target (#231)
authorTobias Hintze <thz@users.noreply.github.com>
Wed, 20 Sep 2017 10:22:49 +0000 (12:22 +0200)
committerBrian Brazil <brian.brazil@robustperception.io>
Wed, 20 Sep 2017 10:22:49 +0000 (11:22 +0100)
Because dialTCP manually resolves name to IP, the
actual tls.DialWithDialer call cannot deduce the name
from the target. This changes puts the "lost" name
into tlsConfig to fix certficate name verification.
Tests are also added which fail without and succeed
with this change.

example.yml
prober/tcp.go
prober/tcp_test.go
prober/utils_test.go

index d91c37a1551b0ff829fc8bdcc7a40d1be1535812..7d09cf7161d2719769c5040d372782127adf5308 100644 (file)
@@ -37,6 +37,11 @@ modules:
       basic_auth:
         username: "username"
         password: "mysecret"
+  tls_connect:
+    prober: tcp
+    timeout: 5s
+    tcp:
+      tls: true
   tcp_connect_example:
     prober: tcp
     timeout: 5s
index d6f0aa567f7f8dab2107db58a2d9a1a97e989f06..e9724bad183ee2c58a166b85e427669399015ec7 100644 (file)
@@ -60,6 +60,17 @@ func dialTCP(ctx context.Context, target string, module config.Module, registry
                level.Error(logger).Log("msg", "Error creating TLS configuration", "err", err)
                return nil, err
        }
+
+       if len(tlsConfig.ServerName) == 0 {
+               // If there is no `server_name` in tls_config, use
+               // targetAddress as TLS-servername. Normally tls.DialWithDialer
+               // would do this for us, but we pre-resolved the name by
+               // `chooseProtocol` and pass the IP-address for dialing (prevents
+               // resolving twice).
+               // For this reason we need to specify the original targetAddress
+               // via tlsConfig to enable hostname verification.
+               tlsConfig.ServerName = targetAddress
+       }
        timeoutDeadline, _ := ctx.Deadline()
        dialer.Deadline = timeoutDeadline
 
index 74771dd595aee70fdf0eecb79e6e4a8f015474f2..68da4760199d475da612c6a2a2ce1b43ae963e2a 100644 (file)
@@ -25,6 +25,7 @@ import (
        "time"
 
        "github.com/go-kit/kit/log"
+       "github.com/go-kit/kit/log/level"
        "github.com/prometheus/client_golang/prometheus"
        pconfig "github.com/prometheus/common/config"
 
@@ -66,6 +67,113 @@ func TestTCPConnectionFails(t *testing.T) {
        }
 }
 
+func TestTCPConnectionWithTLS(t *testing.T) {
+       ln, err := net.Listen("tcp", ":0")
+       if err != nil {
+               t.Fatalf("Error listening on socket: %s", err)
+       }
+       defer ln.Close()
+       _, listenPort, _ := net.SplitHostPort(ln.Addr().String())
+
+       testCTX, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+       defer cancel()
+
+       // Create test certificates valid for 1 day.
+       certExpiry := time.Now().AddDate(0, 0, 1)
+       testcert_pem, testkey_pem := generateTestCertificate(certExpiry, false)
+
+       // CAFile must be passed via filesystem, use a tempfile.
+       tmpCaFile, err := ioutil.TempFile("", "cafile.pem")
+       if err != nil {
+               t.Fatalf(fmt.Sprintf("Error creating CA tempfile: %s", err))
+       }
+       if _, err := tmpCaFile.Write(testcert_pem); err != nil {
+               t.Fatalf(fmt.Sprintf("Error writing CA tempfile: %s", err))
+       }
+       if err := tmpCaFile.Close(); err != nil {
+               t.Fatalf(fmt.Sprintf("Error closing CA tempfile: %s", err))
+       }
+       defer os.Remove(tmpCaFile.Name())
+
+       ch := make(chan (struct{}))
+       logger := log.NewNopLogger()
+       // Handle server side of this test.
+       serverFunc := func() {
+               conn, err := ln.Accept()
+               if err != nil {
+                       panic(fmt.Sprintf("Error accepting on socket: %s", err))
+               }
+               defer conn.Close()
+
+               testcert, err := tls.X509KeyPair(testcert_pem, testkey_pem)
+               if err != nil {
+                       panic(fmt.Sprintf("Failed to decode TLS testing keypair: %s\n", err))
+               }
+
+               // Immediately upgrade to TLS.
+               tlsConfig := &tls.Config{
+                       ServerName:   "localhost",
+                       Certificates: []tls.Certificate{testcert},
+               }
+               tlsConn := tls.Server(conn, tlsConfig)
+               defer tlsConn.Close()
+               if err := tlsConn.Handshake(); err != nil {
+                       level.Error(logger).Log("msg", "Error TLS Handshake (server) failed", "err", err)
+               } else {
+                       // Send some bytes before terminating the connection.
+                       fmt.Fprintf(tlsConn, "Hello World!\n")
+               }
+               ch <- struct{}{}
+       }
+
+       // Expect name-verified TLS connection.
+       module := config.Module{
+               TCP: config.TCPProbe{
+                       TLS: true,
+                       TLSConfig: pconfig.TLSConfig{
+                               CAFile:             tmpCaFile.Name(),
+                               InsecureSkipVerify: false,
+                       },
+               },
+       }
+
+       registry := prometheus.NewRegistry()
+       go serverFunc()
+       // Test name-verification failure (IP without IPs in cert's SAN).
+       if ProbeTCP(testCTX, ln.Addr().String(), module, registry, log.NewNopLogger()) {
+               t.Fatalf("TCP module succeeded, expected failure.")
+       }
+       <-ch
+
+       registry = prometheus.NewRegistry()
+       go serverFunc()
+       // Test name-verification with name from target.
+       target := net.JoinHostPort("localhost", listenPort)
+       if !ProbeTCP(testCTX, target, module, registry, log.NewNopLogger()) {
+               t.Fatalf("TCP module failed, expected success.")
+       }
+       <-ch
+
+       registry = prometheus.NewRegistry()
+       go serverFunc()
+       // Test name-verification against name from tls_config.
+       module.TCP.TLSConfig.ServerName = "localhost"
+       if !ProbeTCP(testCTX, ln.Addr().String(), module, registry, log.NewNopLogger()) {
+               t.Fatalf("TCP module failed, expected success.")
+       }
+       <-ch
+
+       // Check the probe_ssl_earliest_cert_expiry.
+       mfs, err := registry.Gather()
+       if err != nil {
+               t.Fatal(err)
+       }
+       expectedResults := map[string]float64{
+               "probe_ssl_earliest_cert_expiry": float64(certExpiry.Unix()),
+       }
+       checkRegistryResults(expectedResults, mfs, t)
+}
+
 func TestTCPConnectionQueryResponseStartTLS(t *testing.T) {
        ln, err := net.Listen("tcp", "localhost:0")
        if err != nil {
@@ -78,18 +186,18 @@ func TestTCPConnectionQueryResponseStartTLS(t *testing.T) {
 
        // Create test certificates valid for 1 day.
        certExpiry := time.Now().AddDate(0, 0, 1)
-       testcert_pem, testkey_pem := generateTestCertificate(certExpiry)
+       testcert_pem, testkey_pem := generateTestCertificate(certExpiry, true)
 
        // CAFile must be passed via filesystem, use a tempfile.
        tmpCaFile, err := ioutil.TempFile("", "cafile.pem")
        if err != nil {
-               panic(fmt.Sprintf("Error creating CA tempfile: %s", err))
+               t.Fatalf(fmt.Sprintf("Error creating CA tempfile: %s", err))
        }
        if _, err := tmpCaFile.Write(testcert_pem); err != nil {
-               panic(fmt.Sprintf("Error writing CA tempfile: %s", err))
+               t.Fatalf(fmt.Sprintf("Error writing CA tempfile: %s", err))
        }
        if err := tmpCaFile.Close(); err != nil {
-               panic(fmt.Sprintf("Error closing CA tempfile: %s", err))
+               t.Fatalf(fmt.Sprintf("Error closing CA tempfile: %s", err))
        }
        defer os.Remove(tmpCaFile.Name())
 
index aa4a9a1c48823eb5c8e36395e62ade6a061b4ad9..7f1c42af953dda9c428b29e4eb227400a5ef4ff4 100644 (file)
@@ -35,7 +35,7 @@ func checkRegistryResults(expRes map[string]float64, mfs []*dto.MetricFamily, t
 // Create test certificate with specified expiry date
 // Certificate will be self-signed and use localhost/127.0.0.1
 // Generated certificate and key are returned in PEM encoding
-func generateTestCertificate(expiry time.Time) ([]byte, []byte) {
+func generateTestCertificate(expiry time.Time, IPAddressSAN bool) ([]byte, []byte) {
        privatekey, err := rsa.GenerateKey(rand.Reader, 2048)
        if err != nil {
                panic(fmt.Sprintf("Error creating rsa key: %s", err))
@@ -56,8 +56,10 @@ func generateTestCertificate(expiry time.Time) ([]byte, []byte) {
                KeyUsage:    x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
        }
        cert.DNSNames = append(cert.DNSNames, "localhost")
-       cert.IPAddresses = append(cert.IPAddresses, net.ParseIP("127.0.0.1"))
-       cert.IPAddresses = append(cert.IPAddresses, net.ParseIP("::1"))
+       if IPAddressSAN {
+               cert.IPAddresses = append(cert.IPAddresses, net.ParseIP("127.0.0.1"))
+               cert.IPAddresses = append(cert.IPAddresses, net.ParseIP("::1"))
+       }
        derCert, err := x509.CreateCertificate(rand.Reader, &cert, &cert, publickey, privatekey)
        if err != nil {
                panic(fmt.Sprintf("Error signing test-certificate: %s", err))