Revert "Use preferred protocol first when resolving hostname (#728)"
authorJulien Pivotto <roidelapluie@inuits.eu>
Fri, 28 May 2021 21:23:21 +0000 (23:23 +0200)
committerJulien Pivotto <roidelapluie@inuits.eu>
Fri, 28 May 2021 21:42:17 +0000 (23:42 +0200)
This reverts commit 847b668e93267b1b57a76db5876f123bff074315.

Signed-off-by: Julien Pivotto <roidelapluie@inuits.eu>
prober/utils.go
prober/utils_test.go

index 39aef27075b4fcd605c1302ab9da92a537bdfa04..98c9152ea36efa0d643467c103353e06a40b7ae0 100644 (file)
@@ -26,13 +26,8 @@ import (
        "github.com/prometheus/client_golang/prometheus"
 )
 
-var protocolToGauge = map[string]float64{
-       "ip4": 4,
-       "ip6": 6,
-}
-
 // Returns the IP for the IPProtocol and lookup time.
-func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger log.Logger) (ip *net.IPAddr, lookupTime float64, returnerr error) {
+func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger log.Logger) (ip *net.IPAddr, lookupTime float64, err error) {
        var fallbackProtocol string
        probeDNSLookupTimeSeconds := prometheus.NewGauge(prometheus.GaugeOpts{
                Name: "probe_dns_lookup_time_seconds",
@@ -59,48 +54,64 @@ func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol b
                IPProtocol = "ip4"
                fallbackProtocol = "ip6"
        }
-       var usedProtocol string
 
+       level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", IPProtocol)
        resolveStart := time.Now()
 
        defer func() {
                lookupTime = time.Since(resolveStart).Seconds()
                probeDNSLookupTimeSeconds.Add(lookupTime)
-               if usedProtocol != "" {
-                       probeIPProtocolGauge.Set(protocolToGauge[usedProtocol])
-               }
-               if ip != nil {
-                       probeIPAddrHash.Set(ipHash(ip.IP))
-               }
        }()
 
        resolver := &net.Resolver{}
-
-       level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", IPProtocol)
-       if ips, err := resolver.LookupIP(ctx, IPProtocol, target); err == nil {
-               level.Info(logger).Log("msg", "Resolved target address", "ip", ips[0].String())
-               usedProtocol = IPProtocol
-               ip = &net.IPAddr{IP: ips[0]}
-               return
-       } else if !fallbackIPProtocol {
+       ips, err := resolver.LookupIPAddr(ctx, target)
+       if err != nil {
                level.Error(logger).Log("msg", "Resolution with IP protocol failed", "err", err)
-               returnerr = fmt.Errorf("unable to find ip; no fallback: %s", err)
-               return
+               return nil, 0.0, err
        }
 
-       level.Info(logger).Log("msg", "Resolving target address", "ip_protocol", fallbackProtocol)
-       ips, err := resolver.LookupIP(ctx, fallbackProtocol, target)
-       if err != nil {
-               // This could happen when the domain don't have A and AAAA record (e.g.
-               // only have MX record).
-               level.Error(logger).Log("msg", "Resolution with IP protocol failed", "err", err)
-               returnerr = fmt.Errorf("unable to find ip; exhausted fallback: %s", err)
-               return
+       // Return the IP in the requested protocol.
+       var fallback *net.IPAddr
+       for _, ip := range ips {
+               switch IPProtocol {
+               case "ip4":
+                       if ip.IP.To4() != nil {
+                               level.Info(logger).Log("msg", "Resolved target address", "ip", ip.String())
+                               probeIPProtocolGauge.Set(4)
+                               probeIPAddrHash.Set(ipHash(ip.IP))
+                               return &ip, lookupTime, nil
+                       }
+
+                       // ip4 as fallback
+                       fallback = &ip
+
+               case "ip6":
+                       if ip.IP.To4() == nil {
+                               level.Info(logger).Log("msg", "Resolved target address", "ip", ip.String())
+                               probeIPProtocolGauge.Set(6)
+                               probeIPAddrHash.Set(ipHash(ip.IP))
+                               return &ip, lookupTime, nil
+                       }
+
+                       // ip6 as fallback
+                       fallback = &ip
+               }
+       }
+
+       // Unable to find ip and no fallback set.
+       if fallback == nil || !fallbackIPProtocol {
+               return nil, 0.0, fmt.Errorf("unable to find ip; no fallback")
+       }
+
+       // Use fallback ip protocol.
+       if fallbackProtocol == "ip4" {
+               probeIPProtocolGauge.Set(4)
+       } else {
+               probeIPProtocolGauge.Set(6)
        }
-       level.Info(logger).Log("msg", "Resolved target address", "ip", ips[0].String())
-       usedProtocol = fallbackProtocol
-       ip = &net.IPAddr{IP: ips[0]}
-       return
+       probeIPAddrHash.Set(ipHash(fallback.IP))
+       level.Info(logger).Log("msg", "Resolved target address", "ip", fallback.String())
+       return fallback, lookupTime, nil
 }
 
 func ipHash(ip net.IP) float64 {
index a395ccce4b3077360a83f43b1322bc1a424f4256..95ad8d2537ca39a45c060c4d640dd691220a6f44 100644 (file)
@@ -24,7 +24,6 @@ import (
        "math/big"
        "net"
        "os"
-       "strings"
        "testing"
        "time"
 
@@ -163,7 +162,7 @@ func TestChooseProtocol(t *testing.T) {
        registry = prometheus.NewPedanticRegistry()
 
        ip, _, err = chooseProtocol(ctx, "ip4", false, "ipv6.google.com", registry, logger)
-       if err != nil && !strings.HasPrefix(err.Error(), "unable to find ip; no fallback") {
+       if err != nil && err.Error() != "unable to find ip; no fallback" {
                t.Error(err)
        } else if err == nil {
                t.Error("should set error")
@@ -171,17 +170,6 @@ func TestChooseProtocol(t *testing.T) {
        if ip != nil {
                t.Error("without fallback it should not answer")
        }
-
-       registry = prometheus.NewPedanticRegistry()
-       ip, _, err = chooseProtocol(ctx, "ip4", true, "does-not-exist.example.com", registry, logger)
-       if err != nil && !strings.HasPrefix(err.Error(), "unable to find ip; exhausted fallback") {
-               t.Error(err)
-       } else if err == nil {
-               t.Error("should set error")
-       }
-       if ip != nil {
-               t.Error("with exhausted fallback it should not answer")
-       }
 }
 
 func checkMetrics(expected map[string]map[string]map[string]struct{}, mfs []*dto.MetricFamily, t *testing.T) {