Better handling of HTTP redirects.
authorBrian Brazil <brian.brazil@robustperception.io>
Fri, 27 Sep 2019 12:44:49 +0000 (13:44 +0100)
committerBrian Brazil <brian.brazil@robustperception.io>
Fri, 11 Oct 2019 13:07:54 +0000 (14:07 +0100)
If the redirect is to a different host, don't set ServerName.
Fixes #237.

Signed-off-by: Brian Brazil <brian.brazil@robustperception.io>
prober/http.go
prober/http_test.go

index 0899231e546b8dd58214289f97b864ba70cc36ae..95d65f14a444d702e0f30669098439f7f4cea97e 100644 (file)
@@ -145,28 +145,45 @@ type roundTripTrace struct {
 
 // transport is a custom transport keeping traces for each HTTP roundtrip.
 type transport struct {
-       Transport http.RoundTripper
-       logger    log.Logger
-       traces    []*roundTripTrace
-       current   *roundTripTrace
+       Transport             http.RoundTripper
+       NoServerNameTransport http.RoundTripper
+       firstHost             string
+       logger                log.Logger
+       traces                []*roundTripTrace
+       current               *roundTripTrace
 }
 
-func newTransport(rt http.RoundTripper, logger log.Logger) *transport {
+func newTransport(rt, noServerName http.RoundTripper, logger log.Logger) *transport {
        return &transport{
-               Transport: rt,
-               logger:    logger,
-               traces:    []*roundTripTrace{},
+               Transport:             rt,
+               NoServerNameTransport: noServerName,
+               logger:                logger,
+               traces:                []*roundTripTrace{},
        }
 }
 
 // RoundTrip switches to a new trace, then runs embedded RoundTripper.
 func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
+       level.Info(t.logger).Log("msg", "Making HTTP request", "url", req.URL.String(), "host", req.Host)
+
        trace := &roundTripTrace{}
        if req.URL.Scheme == "https" {
                trace.tls = true
        }
        t.current = trace
        t.traces = append(t.traces, trace)
+
+       if t.firstHost == "" {
+               t.firstHost = req.URL.Host
+       }
+
+       if t.firstHost != req.URL.Host {
+               // This is a redirect to something other than the initial host,
+               // so TLS ServerName should not be set.
+               level.Info(t.logger).Log("msg", "Address does not match first address, not sending TLS ServerName", "first", t.firstHost, "address", req.URL.Host)
+               return t.NoServerNameTransport.RoundTrip(req)
+       }
+
        return t.Transport.RoundTrip(req)
 }
 
@@ -294,6 +311,13 @@ func ProbeHTTP(ctx context.Context, target string, module config.Module, registr
                return false
        }
 
+       httpClientConfig.TLSConfig.ServerName = ""
+       noServerName, err := pconfig.NewRoundTripperFromConfig(httpClientConfig, "http_probe", true)
+       if err != nil {
+               level.Error(logger).Log("msg", "Error generating HTTP client without ServerName", "err", err)
+               return false
+       }
+
        jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
        if err != nil {
                level.Error(logger).Log("msg", "Error generating cookiejar", "err", err)
@@ -301,12 +325,13 @@ func ProbeHTTP(ctx context.Context, target string, module config.Module, registr
        }
        client.Jar = jar
 
-       // Inject transport that tracks trace for each redirect.
-       tt := newTransport(client.Transport, logger)
+       // Inject transport that tracks traces for each redirect,
+       // and does not set TLS ServerNames on redirect if needed.
+       tt := newTransport(client.Transport, noServerName, logger)
        client.Transport = tt
 
        client.CheckRedirect = func(r *http.Request, via []*http.Request) error {
-               level.Info(logger).Log("msg", "Received redirect", "url", r.URL.String())
+               level.Info(logger).Log("msg", "Received redirect", "location", r.Response.Header.Get("Location"))
                redirects = len(via)
                if redirects > 10 || httpConfig.NoFollowRedirects {
                        level.Info(logger).Log("msg", "Not following redirect")
@@ -355,8 +380,6 @@ func ProbeHTTP(ctx context.Context, target string, module config.Module, registr
                request.Header.Set(key, value)
        }
 
-       level.Info(logger).Log("msg", "Making HTTP request", "url", request.URL.String(), "host", request.Host)
-
        trace := &httptrace.ClientTrace{
                DNSStart:             tt.DNSStart,
                DNSDone:              tt.DNSDone,
index 5dc86579a07002706d26a954f90bcbb1c5207e75..637784f35ed7df2c05bd79d4fc8e695077257990 100644 (file)
@@ -649,6 +649,27 @@ func TestHTTPUsesTargetAsTLSServerName(t *testing.T) {
        }
 }
 
+func TestRedirectToTLSHostWorks(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping network dependant test")
+       }
+       ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               http.Redirect(w, r, "https://prometheus.io", http.StatusFound)
+       }))
+       defer ts.Close()
+
+       // Follow redirect, should succeed with 200.
+       registry := prometheus.NewRegistry()
+       testCTX, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+       defer cancel()
+       result := ProbeHTTP(testCTX, ts.URL,
+               config.Module{Timeout: time.Second, HTTP: config.HTTPProbe{IPProtocolFallback: true}}, registry, log.NewNopLogger())
+       if !result {
+               t.Fatalf("Redirect test failed unexpectedly")
+       }
+
+}
+
 func TestHTTPPhases(t *testing.T) {
        ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        }))