From 3f69b389ad9fa9529b233470f78d09783ea945ab Mon Sep 17 00:00:00 2001 From: Tobias Hintze Date: Wed, 20 Sep 2017 12:22:49 +0200 Subject: [PATCH] [prober/tcp] get servername for TLS from target (#231) Because dialTCP manually resolves name to IP, the actual tls.DialWithDialer call cannot deduce the name from the target. This changes puts the "lost" name into tlsConfig to fix certficate name verification. Tests are also added which fail without and succeed with this change. --- example.yml | 5 ++ prober/tcp.go | 11 ++++ prober/tcp_test.go | 116 +++++++++++++++++++++++++++++++++++++++++-- prober/utils_test.go | 8 +-- 4 files changed, 133 insertions(+), 7 deletions(-) diff --git a/example.yml b/example.yml index d91c37a..7d09cf7 100644 --- a/example.yml +++ b/example.yml @@ -37,6 +37,11 @@ modules: basic_auth: username: "username" password: "mysecret" + tls_connect: + prober: tcp + timeout: 5s + tcp: + tls: true tcp_connect_example: prober: tcp timeout: 5s diff --git a/prober/tcp.go b/prober/tcp.go index d6f0aa5..e9724ba 100644 --- a/prober/tcp.go +++ b/prober/tcp.go @@ -60,6 +60,17 @@ func dialTCP(ctx context.Context, target string, module config.Module, registry level.Error(logger).Log("msg", "Error creating TLS configuration", "err", err) return nil, err } + + if len(tlsConfig.ServerName) == 0 { + // If there is no `server_name` in tls_config, use + // targetAddress as TLS-servername. Normally tls.DialWithDialer + // would do this for us, but we pre-resolved the name by + // `chooseProtocol` and pass the IP-address for dialing (prevents + // resolving twice). + // For this reason we need to specify the original targetAddress + // via tlsConfig to enable hostname verification. + tlsConfig.ServerName = targetAddress + } timeoutDeadline, _ := ctx.Deadline() dialer.Deadline = timeoutDeadline diff --git a/prober/tcp_test.go b/prober/tcp_test.go index 74771dd..68da476 100644 --- a/prober/tcp_test.go +++ b/prober/tcp_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" "github.com/prometheus/client_golang/prometheus" pconfig "github.com/prometheus/common/config" @@ -66,6 +67,113 @@ func TestTCPConnectionFails(t *testing.T) { } } +func TestTCPConnectionWithTLS(t *testing.T) { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Error listening on socket: %s", err) + } + defer ln.Close() + _, listenPort, _ := net.SplitHostPort(ln.Addr().String()) + + testCTX, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Create test certificates valid for 1 day. + certExpiry := time.Now().AddDate(0, 0, 1) + testcert_pem, testkey_pem := generateTestCertificate(certExpiry, false) + + // CAFile must be passed via filesystem, use a tempfile. + tmpCaFile, err := ioutil.TempFile("", "cafile.pem") + if err != nil { + t.Fatalf(fmt.Sprintf("Error creating CA tempfile: %s", err)) + } + if _, err := tmpCaFile.Write(testcert_pem); err != nil { + t.Fatalf(fmt.Sprintf("Error writing CA tempfile: %s", err)) + } + if err := tmpCaFile.Close(); err != nil { + t.Fatalf(fmt.Sprintf("Error closing CA tempfile: %s", err)) + } + defer os.Remove(tmpCaFile.Name()) + + ch := make(chan (struct{})) + logger := log.NewNopLogger() + // Handle server side of this test. + serverFunc := func() { + conn, err := ln.Accept() + if err != nil { + panic(fmt.Sprintf("Error accepting on socket: %s", err)) + } + defer conn.Close() + + testcert, err := tls.X509KeyPair(testcert_pem, testkey_pem) + if err != nil { + panic(fmt.Sprintf("Failed to decode TLS testing keypair: %s\n", err)) + } + + // Immediately upgrade to TLS. + tlsConfig := &tls.Config{ + ServerName: "localhost", + Certificates: []tls.Certificate{testcert}, + } + tlsConn := tls.Server(conn, tlsConfig) + defer tlsConn.Close() + if err := tlsConn.Handshake(); err != nil { + level.Error(logger).Log("msg", "Error TLS Handshake (server) failed", "err", err) + } else { + // Send some bytes before terminating the connection. + fmt.Fprintf(tlsConn, "Hello World!\n") + } + ch <- struct{}{} + } + + // Expect name-verified TLS connection. + module := config.Module{ + TCP: config.TCPProbe{ + TLS: true, + TLSConfig: pconfig.TLSConfig{ + CAFile: tmpCaFile.Name(), + InsecureSkipVerify: false, + }, + }, + } + + registry := prometheus.NewRegistry() + go serverFunc() + // Test name-verification failure (IP without IPs in cert's SAN). + if ProbeTCP(testCTX, ln.Addr().String(), module, registry, log.NewNopLogger()) { + t.Fatalf("TCP module succeeded, expected failure.") + } + <-ch + + registry = prometheus.NewRegistry() + go serverFunc() + // Test name-verification with name from target. + target := net.JoinHostPort("localhost", listenPort) + if !ProbeTCP(testCTX, target, module, registry, log.NewNopLogger()) { + t.Fatalf("TCP module failed, expected success.") + } + <-ch + + registry = prometheus.NewRegistry() + go serverFunc() + // Test name-verification against name from tls_config. + module.TCP.TLSConfig.ServerName = "localhost" + if !ProbeTCP(testCTX, ln.Addr().String(), module, registry, log.NewNopLogger()) { + t.Fatalf("TCP module failed, expected success.") + } + <-ch + + // Check the probe_ssl_earliest_cert_expiry. + mfs, err := registry.Gather() + if err != nil { + t.Fatal(err) + } + expectedResults := map[string]float64{ + "probe_ssl_earliest_cert_expiry": float64(certExpiry.Unix()), + } + checkRegistryResults(expectedResults, mfs, t) +} + func TestTCPConnectionQueryResponseStartTLS(t *testing.T) { ln, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -78,18 +186,18 @@ func TestTCPConnectionQueryResponseStartTLS(t *testing.T) { // Create test certificates valid for 1 day. certExpiry := time.Now().AddDate(0, 0, 1) - testcert_pem, testkey_pem := generateTestCertificate(certExpiry) + testcert_pem, testkey_pem := generateTestCertificate(certExpiry, true) // CAFile must be passed via filesystem, use a tempfile. tmpCaFile, err := ioutil.TempFile("", "cafile.pem") if err != nil { - panic(fmt.Sprintf("Error creating CA tempfile: %s", err)) + t.Fatalf(fmt.Sprintf("Error creating CA tempfile: %s", err)) } if _, err := tmpCaFile.Write(testcert_pem); err != nil { - panic(fmt.Sprintf("Error writing CA tempfile: %s", err)) + t.Fatalf(fmt.Sprintf("Error writing CA tempfile: %s", err)) } if err := tmpCaFile.Close(); err != nil { - panic(fmt.Sprintf("Error closing CA tempfile: %s", err)) + t.Fatalf(fmt.Sprintf("Error closing CA tempfile: %s", err)) } defer os.Remove(tmpCaFile.Name()) diff --git a/prober/utils_test.go b/prober/utils_test.go index aa4a9a1..7f1c42a 100644 --- a/prober/utils_test.go +++ b/prober/utils_test.go @@ -35,7 +35,7 @@ func checkRegistryResults(expRes map[string]float64, mfs []*dto.MetricFamily, t // Create test certificate with specified expiry date // Certificate will be self-signed and use localhost/127.0.0.1 // Generated certificate and key are returned in PEM encoding -func generateTestCertificate(expiry time.Time) ([]byte, []byte) { +func generateTestCertificate(expiry time.Time, IPAddressSAN bool) ([]byte, []byte) { privatekey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { panic(fmt.Sprintf("Error creating rsa key: %s", err)) @@ -56,8 +56,10 @@ func generateTestCertificate(expiry time.Time) ([]byte, []byte) { KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, } cert.DNSNames = append(cert.DNSNames, "localhost") - cert.IPAddresses = append(cert.IPAddresses, net.ParseIP("127.0.0.1")) - cert.IPAddresses = append(cert.IPAddresses, net.ParseIP("::1")) + if IPAddressSAN { + cert.IPAddresses = append(cert.IPAddresses, net.ParseIP("127.0.0.1")) + cert.IPAddresses = append(cert.IPAddresses, net.ParseIP("::1")) + } derCert, err := x509.CreateCertificate(rand.Reader, &cert, &cert, publickey, privatekey) if err != nil { panic(fmt.Sprintf("Error signing test-certificate: %s", err)) -- 2.25.1