diff --git a/README.md b/README.md index 5a4bbde..3943c30 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ this `http://server-ip:port/api/v1/scan/globalcyberalliance.org`, which will ret "alt3.aspmx.l.google.com.", "alt4.aspmx.l.google.com." ], - "spf": "v=spf1 include:_u.globalcyberalliance.org._spf.smart.ondmarc.com -all", + "spf": "v=spf1 include:_u.globalcyberalliance.org._spf.smart.ondmarc.com -all" }, "advice": { "bimi": [ @@ -203,7 +203,8 @@ Which will return a JSON response like this: ## Serve Dedicated Mailbox -You can also serve scan results via a dedicated mailbox. It is advised that you use this mailbox for this sole purpose, as all emails will be deleted at each 10 second interval. +You can also serve scan results via a dedicated mailbox. It is advised that you use this mailbox for this sole purpose, +as all emails will be deleted at each 10 second interval. ```shell dss serve mail --inboundHost "imap.gmail.com:993" --inboundPass "SomePassword" --inboundUser "SomeAddress@domain.tld" --outboundHost "smtp.gmail.com:587" --outboundPass "SomePassword" --outboundUser "SomeAddress@domain.tld" --advise @@ -223,6 +224,7 @@ You can then email this inbox from any address, and you'll receive an email back | `--dkimSelector` | | Specify a comma seperated list of DKIM selectors (default "") | | `--dnsBuffer` | | Specify the allocated buffer for DNS responses (default 4096) | | `--dnsProtocol` | | Protocol to use for DNS queries (udp, tcp, tcp-tls) (default udp) | +| `--dnssec` | | Include scan for DNSSEC records | | `--format` | `-f` | Format to print results in (yaml, json, csv) (default "yaml") | | `--nameservers` | `-n` | Use specific nameservers, in host[:port] format; may be specified multiple times | | `--outputFile` | `-o` | Output the results to a specified file (creates a file with the current unix timestamp if no file is specified) | diff --git a/cmd/dss/config.go b/cmd/dss/config.go index 2ced2b7..5a77faa 100644 --- a/cmd/dss/config.go +++ b/cmd/dss/config.go @@ -32,7 +32,7 @@ var ( case "nameservers": printToConsole("nameservers: " + cast.ToString(cfg.Nameservers)) default: - log.Fatal().Msg("unknown config key") + log.Fatal().Msg("Unknown config key") } }, } @@ -47,11 +47,11 @@ var ( case "nameservers": cfg.Nameservers = strings.Split(args[1], ",") default: - log.Fatal().Msg("unknown config key") + log.Fatal().Msg("Unknown config key") } if err := cfg.Save(); err != nil { - log.Fatal().Err(err).Msg("unable to save config") + log.Fatal().Err(err).Msg("Unable to save config") } log.Info().Msg("Config updated") @@ -93,7 +93,7 @@ func (c *Config) Load() error { // create config if it doesn't exist if _, err := os.Stat(c.path); os.IsNotExist(err) { if err = os.MkdirAll(c.dir, os.ModePerm); err != nil { - log.Fatal().Err(err).Msg("failed to create config directory") + log.Fatal().Err(err).Msg("Failed to create config directory") } if err = c.Save(); err != nil { @@ -104,11 +104,11 @@ func (c *Config) Load() error { // read config configData, err := os.ReadFile(c.path) if err != nil { - log.Fatal().Err(err).Msg("unable to read config file") + log.Fatal().Err(err).Msg("Unable to read config file") } if err = yaml.Unmarshal(configData, &c); err != nil { - log.Fatal().Err(err).Msg("unable to unmarshal config values") + log.Fatal().Err(err).Msg("Unable to unmarshal config values") } return nil @@ -117,7 +117,7 @@ func (c *Config) Load() error { func (c *Config) Save() error { configData, err := yaml.Marshal(c) if err != nil { - log.Fatal().Err(err).Msg("unable to marshal default config") + log.Fatal().Err(err).Msg("Unable to marshal default config") } return os.WriteFile(c.path, configData, os.ModePerm) diff --git a/cmd/dss/main.go b/cmd/dss/main.go index 49d95d0..2faa664 100644 --- a/cmd/dss/main.go +++ b/cmd/dss/main.go @@ -26,7 +26,7 @@ var ( Use: "dss", Short: "Scan a domain's DNS records.", Long: "Scan a domain's DNS records.\nhttps://github.com/globalcyberalliance/domain-security-scanner", - Version: "3.0.16", + Version: "3.0.17", PersistentPreRun: func(cmd *cobra.Command, args []string) { var logWriter io.Writer diff --git a/cmd/dss/scan.go b/cmd/dss/scan.go index 17f80d1..dc2d6a5 100644 --- a/cmd/dss/scan.go +++ b/cmd/dss/scan.go @@ -4,7 +4,6 @@ import ( "bufio" "os" - "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/advisor" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/model" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/scanner" "github.com/spf13/cobra" @@ -26,19 +25,17 @@ var cmdScan = &cobra.Command{ scanner.WithDNSBuffer(dnsBuffer), scanner.WithDNSProtocol(dnsProtocol), scanner.WithNameservers(nameservers), + scanner.WithCheckTLS(checkTLS), } if len(dkimSelector) > 0 { opts = append(opts, scanner.WithDKIMSelectors(dkimSelector...)) } - sc, err := scanner.New(log, timeout, opts...) if err != nil { log.Fatal().Err(err).Msg("An unexpected error occurred.") } - domainAdvisor := advisor.NewAdvisor(timeout, cache, checkTLS) - if format == "csv" && outputFile == "" { log.Info().Msg("CSV header: domain,BIMI,DKIM,DMARC,MX,SPF,TXT,error,advice") } @@ -65,7 +62,7 @@ var cmdScan = &cobra.Command{ } for _, result := range results { - printResult(result, domainAdvisor) + printResult(result, sc) } } @@ -84,12 +81,12 @@ var cmdScan = &cobra.Command{ } for _, result := range results { - printResult(result, domainAdvisor) + printResult(result, sc) } }, } -func printResult(result *scanner.Result, domainAdvisor *advisor.Advisor) { +func printResult(result *scanner.Result, sc *scanner.Scanner) { if result == nil { log.Fatal().Msg("An unexpected error occurred.") } @@ -99,7 +96,7 @@ func printResult(result *scanner.Result, domainAdvisor *advisor.Advisor) { } if advise && result.Error != scanner.ErrInvalidDomain { - resultWithAdvice.Advice = domainAdvisor.CheckAll(result.Domain, result.BIMI, result.DKIM, result.DMARC, result.MX, result.SPF) + resultWithAdvice.Advice = sc.CheckAll(result.Domain, result.BIMI, result.DKIM, result.DMARC, result.DNSSEC, result.MX, result.SPF, result.STS, result.STSPolicy) } printToConsole(resultWithAdvice) diff --git a/cmd/dss/serve.go b/cmd/dss/serve.go index ab2c7cf..29ccefa 100644 --- a/cmd/dss/serve.go +++ b/cmd/dss/serve.go @@ -3,7 +3,6 @@ package main import ( "time" - "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/advisor" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/http" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/mail" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/scanner" @@ -53,6 +52,7 @@ var ( scanner.WithDNSBuffer(dnsBuffer), scanner.WithDNSProtocol(dnsProtocol), scanner.WithNameservers(nameservers), + scanner.WithCheckTLS(checkTLS), } if len(dkimSelector) > 0 { @@ -65,9 +65,6 @@ var ( } server := http.NewServer(log, timeout, cmd.Version) - if advise { - server.Advisor = advisor.NewAdvisor(timeout, cache, checkTLS) - } server.CheckTLS = checkTLS server.Scanner = sc @@ -85,6 +82,7 @@ var ( scanner.WithDNSBuffer(dnsBuffer), scanner.WithDNSProtocol(dnsProtocol), scanner.WithNameservers(nameservers), + scanner.WithCheckTLS(checkTLS), } if len(dkimSelector) > 0 { @@ -96,7 +94,7 @@ var ( log.Fatal().Err(err).Msg("could not create domain scanner") } - mailServer, err := mail.NewMailServer(mailConfig, log, sc, advisor.NewAdvisor(timeout, cache, checkTLS)) + mailServer, err := mail.NewMailServer(mailConfig, log, sc, advise) if err != nil { log.Fatal().Err(err).Msg("could not open mail server connection") } diff --git a/pkg/advisor/advisor_test.go b/pkg/advisor/advisor_test.go deleted file mode 100644 index f2f5c0f..0000000 --- a/pkg/advisor/advisor_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package advisor - -import ( - "reflect" - "testing" - "time" -) - -func TestAdvisor_CheckDMARC(t *testing.T) { - advisor := NewAdvisor(time.Second, time.Second, false) - - t.Run("Missing", func(t *testing.T) { - expectedAdvice := []string{ - "You do not have DMARC setup!", - } - - advice := advisor.CheckDMARC("") - - if !reflect.DeepEqual(advice, expectedAdvice) { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("Malformed", func(t *testing.T) { - expectedAdvice := []string{ - "Your DMARC record appears to be malformed as no semicolons seem to be present.", - } - - advice := advisor.CheckDMARC("v=DMARC1 fo=1") - - if !reflect.DeepEqual(advice, expectedAdvice) { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("FirstTag", func(t *testing.T) { - expectedAdvice := "The beginning of your DMARC record should be v=DMARC1 with specific capitalization." - advice := advisor.CheckDMARC("v=dmarc1;") - - if advice[0] != expectedAdvice { - t.Errorf("found %v, want %v", advice[0], expectedAdvice) - } - }) - - t.Run("SecondTag", func(t *testing.T) { - expectedAdvice := "The second tag in your DMARC record must be p=none/p=quarantine/p=reject." - advice := advisor.CheckDMARC("v=DMARC1; fo=1; p=reject;") - - if advice[0] != expectedAdvice { - t.Errorf("found %v, want %v", advice[0], expectedAdvice) - } - }) - - t.Run("InvalidFailureOption", func(t *testing.T) { - expectedAdvice := "Invalid failure options specified, the record must be fo=0/fo=1/fo=d/fo=s." - advice := advisor.CheckDMARC("v=DMARC1; p=random; fo=random;") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidPercentage", func(t *testing.T) { - expectedAdvice := "Invalid report percentage specified, it must be between 0 and 100." - advice := advisor.CheckDMARC("v=DMARC1; p=none; fo=1; pct=101;") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidPolicy", func(t *testing.T) { - expectedAdvice := "Invalid DMARC policy specified, the record must be p=none/p=quarantine/p=reject." - advice := advisor.CheckDMARC("v=DMARC1; p=random; fo=1;") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidReportIntervalType", func(t *testing.T) { - expectedAdvice := "Invalid report interval specified, it must be a positive integer." - advice := advisor.CheckDMARC("v=DMARC1; p=none; ri=one;") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidReportIntervalValue", func(t *testing.T) { - expectedAdvice := "Invalid report interval specified, it must be a positive value." - advice := advisor.CheckDMARC("v=DMARC1; p=none; ri=-1;") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidRUADestinationAddress", func(t *testing.T) { - expectedAdvice := "Invalid aggregate report destination specified, it should be a valid email address." - advice := advisor.CheckDMARC("v=DMARC1; p=none; fo=1; rua=mailto:dest") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidRUADestinationFormat", func(t *testing.T) { - expectedAdvice := "Invalid aggregate report destination specified, it should begin with mailto:." - advice := advisor.CheckDMARC("v=DMARC1; p=none; fo=1; rua=dest@domain.tld") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidRUFDestinationAddress", func(t *testing.T) { - expectedAdvice := "Invalid forensic report destination specified, it should be a valid email address." - advice := advisor.CheckDMARC("v=DMARC1; p=none; fo=1; ruf=mailto:dest") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidRUFDestinationFormat", func(t *testing.T) { - expectedAdvice := "Invalid forensic report destination specified, it should begin with mailto:." - advice := advisor.CheckDMARC("v=DMARC1; p=none; fo=1; ruf=dest@domain.tld") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("InvalidSubdomainPolicy", func(t *testing.T) { - expectedAdvice := "Invalid subdomain policy specified, the record must be sp=none/sp=quarantine/sp=reject." - advice := advisor.CheckDMARC("v=DMARC1; sp=random; fo=1;") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice, expectedAdvice) - } - }) - - t.Run("MissingSubdomainPolicy", func(t *testing.T) { - expectedAdvice := "Subdomain policy isn't specified, they'll default to the main policy instead." - advice := advisor.CheckDMARC("v=DMARC1; p=reject; fo=1;") - found := false - - for _, a := range advice { - if a == expectedAdvice { - found = true - } - } - - if !found { - t.Errorf("found %v, want %v", advice[0], expectedAdvice[0]) - } - }) -} diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go new file mode 100644 index 0000000..37ebd63 --- /dev/null +++ b/pkg/dns/dns.go @@ -0,0 +1,166 @@ +package dns + +import ( + "errors" + "fmt" + "io" + "net" + "net/netip" + "strings" + "sync/atomic" + "time" + + "github.com/miekg/dns" +) + +const ( + TypeA = dns.TypeA + TypeAAAA = dns.TypeAAAA + TypeCNAME = dns.TypeCNAME + TypeMX = dns.TypeMX + TypeNS = dns.TypeNS + TypeTXT = dns.TypeTXT + TypeSRV = dns.TypeSRV +) + +func NewZoneParser(r io.Reader, origin, file string) *dns.ZoneParser { + return dns.NewZoneParser(r, origin, file) +} + +type ( + Client struct { + // buffer is used to configure the size of the buffer allocated for DNS responses. + Buffer uint16 + + // client is the underlying DNS client used for scans. + client *dns.Client + + // The index of the last-used nameserver, from the nameservers slice. + // + // This field is managed by atomic operations, and should only ever be referenced by the (*Client).getNS() + // method. + lastNameserverID uint32 + + // Nameservers is a slice of "host:port" strings of nameservers to issue queries against. + Nameservers []string + + // Protocol is used to track the initialized protocol, e.g. UDP or TCP. + Protocol string + + DKIMSelectors []string + } +) + +func New(timeout time.Duration, buffer uint16, protocol string, nameservers ...string) (*Client, error) { + if timeout <= 0 { + return nil, errors.New("timeout must be greater than 0") + } + + if buffer <= 0 { + buffer = 4096 + } + + switch protocol { + case "": + protocol = "udp" + case "udp", "tcp", "tcp-tls": + default: + return nil, fmt.Errorf("invalid DNS protocol: %s, valid options: udp, tcp, tcp-tls", protocol) + } + + parsedNameservers, err := ParseNameservers(nameservers) + if err != nil { + return nil, fmt.Errorf("failed to parse nameservers: %w", err) + } + + client := new(dns.Client) + client.Net = protocol + client.Timeout = timeout + + return &Client{ + Buffer: buffer, + client: client, + Nameservers: parsedNameservers, + }, nil +} + +func (s *Client) getNS() string { + return s.Nameservers[int(atomic.AddUint32(&s.lastNameserverID, 1))%len(s.Nameservers)] +} + +func ParseNameservers(nameservers []string) ([]string, error) { + // If the provided slice of nameservers is nil, or has zero + // elements, load up /etc/resolv.conf, and get the "index" + // directives from there. + if len(nameservers) == 0 { + // Check if /etc/resolv.conf exists. + config, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + // If /etc/resolv.conf does not exist, use Google and Cloudflare. + return []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53"}, nil + } + + nameservers = config.Servers + } + + // Make sure each of the nameservers is in the "host:port" format. + // + // The "dns" package requires that you explicitly state the port + // number for the resolvers that get queried. + for index := range nameservers { + addr, err := netip.ParseAddr(nameservers[index]) + if err != nil { + // Might contain a port. + host, port, err := net.SplitHostPort(nameservers[index]) + if err != nil { + return nil, fmt.Errorf("invalid IP address: %s", nameservers[index]) + } + + // Validate IP. + addr, err = netip.ParseAddr(host) + if err != nil { + return nil, fmt.Errorf("invalid IP address: %s", nameservers[index]) + } + + if addr.Is6() { + nameservers[index] = fmt.Sprintf("[%s]:%v", addr.String(), port) + } else { + nameservers[index] = fmt.Sprintf("%s:%v", addr.String(), port) + } + + continue + } + + if addr.Is6() { + nameservers[index] = fmt.Sprintf("[%s]:53", addr.String()) + } else { + nameservers[index] = fmt.Sprintf("%s:53", addr.String()) + } + } + + return nameservers, nil +} + +// ParseZone parses a zone file and returns the found domains. +func ParseZone(zone io.Reader) []string { + zoneParser := dns.NewZoneParser(zone, "", "") + zoneParser.SetIncludeAllowed(true) + + var domains []string + + for tok, ok := zoneParser.Next(); ok; tok, ok = zoneParser.Next() { + if tok.Header().Rrtype == dns.TypeNS { + continue + } + + domain := strings.Trim(tok.Header().Name, ".") + if !strings.Contains(domain, ".") { + // we have an NS record that serves as an anchor, and should skip it + continue + } + + domains = append(domains, domain) + } + + return domains +} diff --git a/pkg/dns/dns_test.go b/pkg/dns/dns_test.go new file mode 100644 index 0000000..a8bfc86 --- /dev/null +++ b/pkg/dns/dns_test.go @@ -0,0 +1,26 @@ +package dns + +import ( + "testing" + "time" +) + +func TestNew(t *testing.T) { + t.Run("InvalidTimeout", func(t *testing.T) { + if _, err := New(0, 0, ""); err.Error() != "timeout must be greater than 0" { + t.Errorf("expected: \"%s\", got: \"%s\"", "timeout must be greater than 0", err.Error()) + } + }) + + t.Run("InvalidProtocol", func(t *testing.T) { + if _, err := New(time.Minute, 0, "protocol"); err.Error() != "invalid DNS protocol: protocol, valid options: udp, tcp, tcp-tls" { + t.Errorf("expected: \"%s\", got: \"%s\"", "invalid DNS protocol: protocol, valid options: udp, tcp, tcp-tls", err.Error()) + } + }) + + t.Run("InvalidNameservers", func(t *testing.T) { + if _, err := New(time.Minute, 0, "", "nameserver"); err.Error() != "failed to parse nameservers: invalid IP address: nameserver" { + t.Errorf("expected: \"%s\", got: \"%s\"", "failed to parse nameservers: invalid IP address: nameserver", err.Error()) + } + }) +} diff --git a/pkg/dns/requests.go b/pkg/dns/requests.go new file mode 100644 index 0000000..8b9d17f --- /dev/null +++ b/pkg/dns/requests.go @@ -0,0 +1,318 @@ +package dns + +import ( + "fmt" + "io" + "net/http" + "strings" + + "github.com/miekg/dns" +) + +const ( + DefaultBIMIPrefix = "v=BIMI1;" + DefaultDKIMPrefix = "v=DKIM1;" + DefaultDMARCPrefix = "v=DMARC1;" + DefaultSPFPrefix = "v=spf1 " + DefaultSTSPrefix = "v=STSv1;" +) + +var ( + BIMIPrefix = DefaultBIMIPrefix + DKIMPrefix = DefaultDKIMPrefix + DMARCPrefix = DefaultDMARCPrefix + SPFPrefix = DefaultSPFPrefix + STSPrefix = DefaultSTSPrefix + + // knownDkimSelectors is a list of known DKIM selectors. + knownDkimSelectors = []string{ + "x", // Generic + "google", // Google + "selector1", // Microsoft + "selector2", // Microsoft + "s1", // Generic + "s2", // Generic + "k1", // MailChimp + "mandrill", // Mandrill + "everlytickey1", // Everlytic + "everlytickey2", // Everlytic + "dkim", // Hetzner + "mxvault", // MxVault + } + + dnssecTypes = map[uint16]string{ + dns.TypeDNSKEY: "DNSKEY", + dns.TypeRRSIG: "RRSIG", + dns.TypeDS: "DS", + dns.TypeNSEC: "NSEC", + dns.TypeNSEC3: "NSEC3", + dns.TypeCDNSKEY: "CDNSKEY", + dns.TypeCDS: "CDS", + } +) + +// TODO: we no longer disregard NXDOMAIN requests. This should be handled downstream. + +func (s *Client) Scan(domain string, recordType uint16, recursiveLookup ...bool) ([]string, error) { + recursion := true + if len(recursiveLookup) > 0 && recursiveLookup[0] == false { + recursion = false + } + + return s.getDNSRecords(domain, recordType, recursion) +} + +// getDNSRecords queries the DNS server for records of a specific type for a domain. +// It returns a slice of strings (the records) and an error if any occurred. +func (s *Client) getDNSRecords(domain string, recordType uint16, recursion bool) (records []string, err error) { + answers, err := s.GetDNSAnswers(domain, recordType) + if err != nil { + return nil, err + } + + if _, ok := dnssecTypes[recordType]; ok { + for _, answer := range answers { + // records = append(records, strings.TrimPrefix(answer.String(), answer.Header().String())) + records = append(records, answer.String()) + } + return records, nil + } + + for _, answer := range answers { + + // Recursively lookup the CNAME record until we reach the underlying DNS record. + if recursion && answer.Header().Rrtype == dns.TypeCNAME { + if t, ok := answer.(*dns.CNAME); ok { + recursiveLookupTxt, err := s.getDNSRecords(t.Target, recordType, recursion) + if err != nil { + return nil, fmt.Errorf("failed to recursively lookup txt record for %v: %w", t.Target, err) + } + + records = append(records, recursiveLookupTxt...) + + continue + } + + answer.Header().Rrtype = recordType + } + + switch record := answer.(type) { + case *dns.A: + records = append(records, record.A.String()) + case *dns.AAAA: + records = append(records, record.AAAA.String()) + case *dns.CNAME: + records = append(records, record.String()) + case *dns.MX: + records = append(records, record.Mx) + case *dns.NS: + records = append(records, record.Ns) + case *dns.TXT: + records = append(records, record.Txt...) + } + } + + return records, nil +} + +// GetDNSAnswers queries the DNS server for answers to a specific question. +// It returns a slice of dns.RR (DNS resource records) and an error if any occurred. +func (s *Client) GetDNSAnswers(domain string, recordType uint16) ([]dns.RR, error) { + req := &dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.SetEdns0(s.Buffer, true) // Specify the buffer size. + req.SetQuestion(dns.Fqdn(domain), recordType) + + in, _, err := s.client.Exchange(req, s.getNS()) + if err != nil { + return nil, err + } + + if in.Rcode != dns.RcodeSuccess { + if in.Rcode == dns.RcodeNameError { + return nil, nil + } + + return nil, fmt.Errorf("DNS query failed with rcode %v", in.Rcode) + } + + if in.MsgHdr.Truncated { + return nil, fmt.Errorf("DNS buffer %v was too small", s.Buffer) + } + + return in.Answer, nil +} + +func (s *Client) GetTypeBIMI(domain string) (string, error) { + for _, dname := range []string{ + "default._bimi." + domain, + domain, + } { + records, err := s.getDNSRecords(dname, dns.TypeTXT, true) + if err != nil { + return "", err + } + + for index, record := range records { + if strings.HasPrefix(record, BIMIPrefix) { + // TXT records can be split across multiple strings, so we need to join them + return strings.Join(records[index:], ""), nil + } + } + } + + return "", nil +} + +// GetTypeDKIM queries the DNS server for DKIM records of a domain. +// It returns a string (DKIM record) and an error if any occurred. +func (s *Client) GetTypeDKIM(domain string) (string, error) { + selectors := append(s.DKIMSelectors, knownDkimSelectors...) + + for _, selector := range selectors { + records, err := s.getDNSRecords(selector+"._domainkey."+domain, dns.TypeTXT, true) + if err != nil { + return "", err + } + + for index, record := range records { + if strings.HasPrefix(record, DKIMPrefix) { + // TXT records can be split across multiple strings, so we need to join them + return strings.Join(records[index:], ""), nil + } + } + } + + return "", nil +} + +// GetTypeDMARC queries the DNS server for DMARC records of a domain. +// It returns a string (DMARC record) and an error if any occurred. +func (s *Client) GetTypeDMARC(domain string) (string, error) { + for _, dname := range []string{ + "_dmarc." + domain, + domain, + } { + records, err := s.getDNSRecords(dname, dns.TypeTXT, true) + if err != nil { + return "", err + } + + for index, record := range records { + if strings.HasPrefix(record, DMARCPrefix) { + // TXT records can be split across multiple strings, so we need to join them + return strings.Join(records[index:], ""), nil + } + } + } + + return "", nil +} + +func (s *Client) GetTypeDNSSEC(domain string) (string, error) { + var dnssecInfo string + var errs []string + + for recordType, recordName := range dnssecTypes { + records, err := s.getDNSRecords(domain, recordType, true) + if err != nil { + errs = append(errs, fmt.Sprintf("failed to query %s: %v\n", recordName, err)) + continue + } + + for index, record := range records { + // Remove domain, TTL, class, and raw data digest. + dnssecInfo += fmt.Sprintf(" %s-%d: %v\n", recordName, index+1, record) + if record == "" { + fmt.Println("Empty record") + } + } + } + if len(errs) == 0 { + return dnssecInfo, nil + } + return dnssecInfo, fmt.Errorf("some DNSSEC record queries failed:\n%s", strings.Join(errs, "\n")) +} + +func (s *Client) GetTypeMTASTS(domain string) (string, string, error) { + for _, dname := range []string{ + "_mta-sts." + domain, + domain, + } { + records, err := s.getDNSRecords(dname, dns.TypeTXT, true) + if err != nil { + return "", "", err + } + + for index, record := range records { + if strings.HasPrefix(record, STSPrefix) { + // TXT records can be split across multiple strings, so we need to join them + response, err := http.Get("https://mta-sts." + domain + "/.well-known/mta-sts.txt") + if err != nil { + return "", "", err + } + + if response.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("failed to get mta-sts record for %v: %v", domain, response.Status) + } + + policy, err := io.ReadAll(response.Body) + if err != nil { + return "", "", err + } + cleanPolicy := strings.ReplaceAll(string(policy), "\r\n", "\n") + return strings.Join(records[index:], ""), cleanPolicy, nil + } + } + } + + return "", "", nil +} + +// GetTypeMX queries the DNS server for SPF records of a domain. +// It returns a string (SPF record) and an error if any occurred. +func (s *Client) GetTypeMX(domain string) ([]string, error) { + records, err := s.getDNSRecords(domain, dns.TypeMX, true) + if err != nil { + return nil, err + } + + return records, nil +} + +func (s *Client) GetTypeNS(domain string) ([]string, error) { + records, err := s.getDNSRecords(domain, dns.TypeNS, true) + if err != nil { + return nil, err + } + + return records, nil +} + +// GetTypeSPF queries the DNS server for SPF records of a domain. +// It returns a string (SPF record) and an error if any occurred. +func (s *Client) GetTypeSPF(domain string) (string, error) { + records, err := s.getDNSRecords(domain, dns.TypeTXT, true) + if err != nil { + return "", err + } + + for _, record := range records { + if strings.HasPrefix(record, SPFPrefix) { + if !strings.Contains(record, "redirect=") { + return record, nil + } + + parts := strings.Fields(record) + for _, part := range parts { + if strings.Contains(part, "redirect=") { + redirectDomain := strings.TrimPrefix(part, "redirect=") + return s.GetTypeSPF(redirectDomain) + } + } + } + } + + return "", nil +} diff --git a/pkg/http/scan.go b/pkg/http/scan.go index 21c1b59..02937ac 100644 --- a/pkg/http/scan.go +++ b/pkg/http/scan.go @@ -53,7 +53,7 @@ func (s *Server) registerScanRoutes() { } if s.Advisor != nil { - result.Advice = s.Advisor.CheckAll(result.ScanResult.Domain, result.ScanResult.BIMI, result.ScanResult.DKIM, result.ScanResult.DMARC, result.ScanResult.MX, result.ScanResult.SPF) + result.Advice = s.Scanner.CheckAll(result.ScanResult.Domain, result.ScanResult.BIMI, result.ScanResult.DKIM, result.ScanResult.DMARC, result.ScanResult.DNSSEC, result.ScanResult.MX, result.ScanResult.SPF, result.ScanResult.STS, result.ScanResult.STSPolicy) } resp.Body.ScanResultWithAdvice = result @@ -98,7 +98,7 @@ func (s *Server) registerScanRoutes() { } if s.Advisor != nil && result.Error != scanner.ErrInvalidDomain { - res.Advice = s.Advisor.CheckAll(result.Domain, result.BIMI, result.DKIM, result.DMARC, result.MX, result.SPF) + res.Advice = s.Scanner.CheckAll(result.Domain, result.BIMI, result.DKIM, result.DMARC, result.DNSSEC, result.MX, result.SPF, result.STS, result.STSPolicy) } resp.Body.Results = append(resp.Body.Results, res) diff --git a/pkg/http/server.go b/pkg/http/server.go index 6d20add..147acb2 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -8,7 +8,6 @@ import ( "github.com/danielgtaylor/huma/v2" "github.com/danielgtaylor/huma/v2/adapters/humachi" - "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/advisor" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/scanner" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -30,7 +29,7 @@ type Server struct { CheckTLS bool // Services used by the various HTTP routes - Advisor *advisor.Advisor + Advisor *scanner.Advisor Scanner *scanner.Scanner } diff --git a/pkg/mail/server.go b/pkg/mail/server.go index 2d24a94..e041521 100644 --- a/pkg/mail/server.go +++ b/pkg/mail/server.go @@ -6,7 +6,6 @@ import ( textTmpl "text/template" "time" - domainAdvisor "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/advisor" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/cache" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/model" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/scanner" @@ -16,7 +15,7 @@ import ( ) type Server struct { - advisor *domainAdvisor.Advisor + giveAdvice bool config Config cooldown *cache.Cache[string] interval time.Duration @@ -28,13 +27,13 @@ type Server struct { } // NewMailServer returns a new instance of a mail server. -func NewMailServer(config Config, logger zerolog.Logger, sc *scanner.Scanner, advisor *domainAdvisor.Advisor) (*Server, error) { +func NewMailServer(config Config, logger zerolog.Logger, sc *scanner.Scanner, giveAdvice bool) (*Server, error) { s := Server{ - advisor: advisor, - config: config, - cooldown: cache.New[string](1 * time.Minute), - logger: logger, - Scanner: sc, + giveAdvice: giveAdvice, + config: config, + cooldown: cache.New[string](1 * time.Minute), + logger: logger, + Scanner: sc, } client, err := s.Login() @@ -112,9 +111,8 @@ func (s *Server) handler() error { resultWithAdvice := model.ScanResultWithAdvice{ ScanResult: result, } - - if s.advisor != nil || result.Error != scanner.ErrInvalidDomain { - resultWithAdvice.Advice = s.advisor.CheckAll(result.Domain, result.BIMI, result.DKIM, result.DMARC, result.MX, result.SPF) + if s.giveAdvice || result.Error != scanner.ErrInvalidDomain { + resultWithAdvice.Advice = s.Scanner.CheckAll(result.Domain, result.BIMI, result.DKIM, result.DMARC, result.DNSSEC, result.MX, result.SPF, result.STS, result.STSPolicy) } if err = s.SendMail(sender, resultWithAdvice); err != nil { diff --git a/pkg/mail/template.go b/pkg/mail/template.go index c1663de..b0b47af 100644 --- a/pkg/mail/template.go +++ b/pkg/mail/template.go @@ -7,8 +7,8 @@ import ( htmlTmpl "html/template" textTmpl "text/template" - "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/advisor" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/model" + "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/scanner" ) var ( @@ -50,7 +50,7 @@ func (s *Server) getMailContents(result model.ScanResultWithAdvice) (string, str var htmlBytes, textBytes bytes.Buffer if result.Advice == nil { - result.Advice = &advisor.Advice{} + result.Advice = &scanner.Advice{} } mailData := struct { @@ -73,7 +73,7 @@ func (s *Server) getMailContents(result model.ScanResultWithAdvice) (string, str // prevent template errors if result.Advice == nil { - result.Advice = &advisor.Advice{} + result.Advice = &scanner.Advice{} } if err := s.templateHTML.Execute(&htmlBytes, mailData); err != nil { diff --git a/pkg/model/scan.go b/pkg/model/scan.go index 305f3f1..a1b0d4f 100644 --- a/pkg/model/scan.go +++ b/pkg/model/scan.go @@ -3,13 +3,12 @@ package model import ( "strings" - "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/advisor" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/scanner" ) type ScanResultWithAdvice struct { ScanResult *scanner.Result `json:"scanResult" yaml:"scanResult" doc:"The results of scanning a domain's DNS records."` - Advice *advisor.Advice `json:"advice,omitempty" yaml:"advice,omitempty" doc:"The advice for the domain's DNS records."` + Advice *scanner.Advice `json:"advice,omitempty" yaml:"advice,omitempty" doc:"The advice for the domain's DNS records."` } func (s *ScanResultWithAdvice) CSV() []string { @@ -19,23 +18,24 @@ func (s *ScanResultWithAdvice) CSV() []string { for _, value := range s.Advice.Domain { advice += "Domain: " + value + "; " } - for _, value := range s.Advice.BIMI { advice += "BIMI: " + value + "; " } - for _, value := range s.Advice.DKIM { advice += "DKIM: " + value + "; " } - for _, value := range s.Advice.DMARC { advice += "DMARC: " + value + "; " } - + for _, value := range s.Advice.DNSSEC { + advice += "DNSSEC: " + value + "; " + } + for _, value := range s.Advice.MTASTS { + advice += "MTA-STS: " + value + "; " + } for _, value := range s.Advice.MX { advice += "MX: " + value + "; " } - for _, value := range s.Advice.SPF { advice += "SPF: " + value + "; " } diff --git a/pkg/advisor/advisor.go b/pkg/scanner/advisor.go similarity index 67% rename from pkg/advisor/advisor.go rename to pkg/scanner/advisor.go index b77e6ee..91e77a7 100644 --- a/pkg/advisor/advisor.go +++ b/pkg/scanner/advisor.go @@ -1,4 +1,4 @@ -package advisor +package scanner import ( "crypto/tls" @@ -18,6 +18,7 @@ import ( var emailRegex = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") type ( + // Advisor config options. Advisor struct { consumerDomains map[string]struct{} consumerDomainsMutex *sync.Mutex @@ -32,8 +33,10 @@ type ( BIMI []string `json:"bimi,omitempty" yaml:"bimi,omitempty" doc:"BIMI advice." example:"Your BIMI record looks good! No further action needed."` DKIM []string `json:"dkim,omitempty" yaml:"dkim,omitempty" doc:"DKIM advice." example:"DKIM is setup for this email server. However, if you have other 3rd party systems, please send a test email to confirm DKIM is setup properly."` DMARC []string `json:"dmarc,omitempty" yaml:"dmarc,omitempty" doc:"DMARC advice." example:"You are currently at the lowest level and receiving reports, which is a great starting point. Please make sure to review the reports, make the appropriate adjustments, and move to either quarantine or reject soon."` + MTASTS []string `json:"mta-sts,omitempty" yaml:"mta-sts,omitempty" doc:"MTA-STS advice." example:"MTA-STS seems to be setup correctly! No further action needed."` MX []string `json:"mx,omitempty" yaml:"mx,omitempty" doc:"MX advice." example:"You have a multiple mail servers setup! No further action needed."` SPF []string `json:"spf,omitempty" yaml:"spf,omitempty" doc:"SPF advice." example:"SPF seems to be setup correctly! No further action needed."` + DNSSEC []string `json:"dnssec,omitempty" yaml:"dnssec,omitempty" doc:"DNSSEC advice." example:"DNSSEC seems to be setup correctly! No further action needed."` } // dmarc represents the structure of a DMARC record. @@ -52,9 +55,8 @@ type ( } ) -func NewAdvisor(timeout time.Duration, cacheLifetime time.Duration, checkTLS bool) *Advisor { +func NewAdvisor(timeout time.Duration, cacheLifetime time.Duration) *Advisor { advisor := Advisor{ - checkTLS: checkTLS, consumerDomains: make(map[string]struct{}), consumerDomainsMutex: &sync.Mutex{}, dialer: &net.Dialer{Timeout: timeout}, @@ -69,38 +71,48 @@ func NewAdvisor(timeout time.Duration, cacheLifetime time.Duration, checkTLS boo return &advisor } -func (a *Advisor) CheckAll(domain, bimi, dkim, dmarc string, mx []string, spf string) *Advice { +func (s *Scanner) CheckAll(domain, bimi, dkim, dmarc string, dnssec string, mx []string, spf string, sts string, stsPolicy string) *Advice { advice := &Advice{} var wg sync.WaitGroup - wg.Add(6) + wg.Add(8) go func() { - advice.Domain = a.CheckDomain(domain) + advice.Domain = s.CheckDomain(domain) wg.Done() }() go func() { - advice.BIMI = a.CheckBIMI(bimi) + advice.BIMI = s.CheckBIMI(bimi) wg.Done() }() go func() { - advice.DKIM = a.CheckDKIM(dkim) + advice.DKIM = s.CheckDKIM(dkim) wg.Done() }() go func() { - advice.DMARC = a.CheckDMARC(dmarc) + advice.DMARC = s.CheckDMARC(dmarc) wg.Done() }() go func() { - advice.MX = a.CheckMX(mx) + advice.DNSSEC = s.CheckDNSSEC(dnssec) wg.Done() }() go func() { - advice.SPF = a.CheckSPF(spf) + advice.MTASTS = s.CheckMTASTS(sts, stsPolicy) + wg.Done() + }() + + go func() { + advice.MX = s.CheckMX(mx) + wg.Done() + }() + + go func() { + advice.SPF = s.CheckSPF(spf) wg.Done() }() @@ -109,7 +121,7 @@ func (a *Advisor) CheckAll(domain, bimi, dkim, dmarc string, mx []string, spf st return advice } -func (a *Advisor) CheckBIMI(bimi string) (advice []string) { +func (s *Scanner) CheckBIMI(bimi string) (advice []string) { if len(bimi) == 0 { return []string{"We couldn't detect any active BIMI record for your domain. Please visit https://dmarcguide.globalcyberalliance.org to fix this."} } @@ -129,35 +141,37 @@ func (a *Advisor) CheckBIMI(bimi string) (advice []string) { svgFound = true tagValue := strings.TrimPrefix(tag, "l=") - // download SVG logo + // Download SVG logo. response, err := http.Head(tagValue) if err != nil || response == nil { advice = append(advice, "Your SVG logo could not be downloaded.") continue } - defer response.Body.Close() if response.StatusCode != http.StatusOK { advice = append(advice, "Your SVG logo could not be downloaded.") + response.Body.Close() continue } if response.ContentLength > int64(32*1024) { advice = append(advice, "Your SVG logo exceeds the maximum of 32KB.") } + + response.Body.Close() } if strings.Contains(tag, "a=") { vmcFound = true tagValue := strings.TrimPrefix(tag, "a=") - // download VMC cert + // Download VMC cert. response, err := http.Head(tagValue) if err != nil || response == nil { advice = append(advice, "Your VMC certificate could not be downloaded.") continue } - defer response.Body.Close() + response.Body.Close() if response.StatusCode != http.StatusOK { advice = append(advice, "Your VMC certificate could not be downloaded.") @@ -187,7 +201,7 @@ func (a *Advisor) CheckBIMI(bimi string) (advice []string) { return advice } -func (a *Advisor) CheckDKIM(dkim string) (advice []string) { +func (s *Scanner) CheckDKIM(dkim string) (advice []string) { if dkim == "" { return []string{"We couldn't detect any active DKIM record for your domain. Due to how DKIM works, we only lookup common/known DKIM selectors (such as x, selector1, google). Visit https://dmarcguide.globalcyberalliance.org for more info on how to configure DKIM for your domain."} } @@ -224,7 +238,7 @@ func (a *Advisor) CheckDKIM(dkim string) (advice []string) { return advice } -func (a *Advisor) CheckDMARC(record string) (advice []string) { +func (s *Scanner) CheckDMARC(record string) (advice []string) { if record == "" { return []string{"You do not have DMARC setup!"} } @@ -236,6 +250,7 @@ func (a *Advisor) CheckDMARC(record string) (advice []string) { dmarcRecord := dmarc{} parts := strings.Split(record, ";") ruaExists := strings.Contains(record, "rua=") + var vFound, pFound bool for index, part := range parts { keyValue := strings.SplitN(strings.TrimSpace(part), "=", 2) @@ -248,12 +263,14 @@ func (a *Advisor) CheckDMARC(record string) (advice []string) { switch key { case "v": + vFound = true if index != 0 || value != "DMARC1" { dmarcRecord.Advice = append(dmarcRecord.Advice, "The beginning of your DMARC record should be v=DMARC1 with specific capitalization.") } dmarcRecord.Version = value case "p": + pFound = true if index != 1 { dmarcRecord.Advice = append(dmarcRecord.Advice, "The second tag in your DMARC record must be p=none/p=quarantine/p=reject.") } @@ -324,8 +341,16 @@ func (a *Advisor) CheckDMARC(record string) (advice []string) { dmarcRecord.Advice = append(dmarcRecord.Advice, "Invalid failure options specified, the record must be fo=0/fo=1/fo=d/fo=s.") } case "aspf": + if value != "r" && value != "s" { + dmarcRecord.Advice = append(dmarcRecord.Advice, "aspf value is invalid, must be 'r' or 's'") + } + dmarcRecord.ASPF = value case "adkim": + if value != "r" && value != "s" { + dmarcRecord.Advice = append(dmarcRecord.Advice, "adkim value is invalid, must be 'r' or 's'") + } + dmarcRecord.ADKIM = value case "ri": ri, err := strconv.Atoi(value) @@ -341,6 +366,14 @@ func (a *Advisor) CheckDMARC(record string) (advice []string) { } } + if !vFound { + dmarcRecord.Advice = append(dmarcRecord.Advice, "The first tag in your DMARC record should be v=DMARC1") + } + + if !pFound { + dmarcRecord.Advice = append(dmarcRecord.Advice, "No DMARC policy found, record must contain p=none/p=quarantine/p=reject") + } + if len(dmarcRecord.AggregateReportDestination) == 0 { dmarcRecord.Advice = append(dmarcRecord.Advice, "Consider specifying a 'rua' tag for aggregate reporting.") } @@ -360,16 +393,23 @@ func (a *Advisor) CheckDMARC(record string) (advice []string) { return dmarcRecord.Advice } -func (a *Advisor) CheckDomain(domain string) (advice []string) { - a.consumerDomainsMutex.Lock() - if _, ok := a.consumerDomains[domain]; ok { - a.consumerDomainsMutex.Unlock() +func (s *Scanner) CheckDNSSEC(dnssec string) (advice []string) { + if dnssec == "" { + return []string{"We couldn't detect any active DNSSEC record for your domain."} + } + return []string{"DNSSEC seems to be setup correctly! No further action needed."} +} + +func (s *Scanner) CheckDomain(domain string) (advice []string) { + s.advisor.consumerDomainsMutex.Lock() + if _, ok := s.advisor.consumerDomains[domain]; ok { + s.advisor.consumerDomainsMutex.Unlock() return []string{"Consumer based accounts (i.e gmail.com, yahoo.com, etc) are controlled by the vendor. They are responsible for setting DKIM, SPF and DMARC capabilities on their domains."} } - a.consumerDomainsMutex.Unlock() + s.advisor.consumerDomainsMutex.Unlock() - if a.checkTLS { - advice = append(advice, a.checkHostTLS(domain, 443)...) + if s.advisor.checkTLS { + advice = append(advice, s.checkHostTLS(domain, 443)...) } if len(advice) == 0 { @@ -379,7 +419,7 @@ func (a *Advisor) CheckDomain(domain string) (advice []string) { return advice } -func (a *Advisor) CheckMX(mx []string) (advice []string) { +func (s *Scanner) CheckMX(mx []string) (advice []string) { switch len(mx) { case 0: return []string{"You do not have any mail servers setup, so you cannot receive email at this domain."} @@ -389,10 +429,10 @@ func (a *Advisor) CheckMX(mx []string) (advice []string) { advice = []string{"You have multiple mail servers setup, which is recommended."} } - if a.checkTLS { + if s.advisor.checkTLS { for _, serverAddress := range mx { // prepend the hostname to the advice line - mxAdvice := a.checkMailTls(serverAddress) + mxAdvice := s.checkMailTls(serverAddress) for _, serverAdvice := range mxAdvice { // strip the trailing dot from DNS records advice = append(advice, serverAddress[:len(serverAddress)-1]+": "+serverAdvice) @@ -422,47 +462,71 @@ func (a *Advisor) CheckMX(mx []string) (advice []string) { return advice } -func (a *Advisor) CheckSPF(spf string) []string { +func (s *Scanner) CheckSPF(spf string) []string { if spf == "" { return []string{"We couldn't detect any active SPF record for your domain. Please visit https://dmarcguide.globalcyberalliance.org to fix this."} } + var advice []string + + if !strings.HasPrefix(spf, "v=spf1") { + advice = append(advice, "Your SPF record should begin with v=spf1") + } + + lookupCount := 0 + lookupError := s.checkSPFLookup(spf, []string{}, &lookupCount) + if lookupError != "" { + advice = append(advice, lookupError) + } + + if lookupCount > 10 { + advice = append(advice, "Your SPF record contains "+strconv.Itoa(lookupCount)+" DNS lookups, which is more than the 10 lookup limit. Your SPF record check will fail, consider using 'ip4' and 'ip6' mechanisms instead.") + } + + if strings.Contains(spf, "ptr") { + advice = append(advice, "The 'ptr' mechanism is deprecated, and is unreliable. It is strongly recommended to not use it.") + } + if strings.Contains(spf, "all") { if strings.Contains(spf, "+all") { - return []string{"Your SPF record contains the +all tag. It is strongly recommended that this be changed to either -all or ~all. The +all tag allows for any system regardless of SPF to send mail on the organization’s behalf."} + advice = append(advice, "Your SPF record contains the +all tag. It is strongly recommended that this be changed to either -all or ~all. The +all tag allows for any system regardless of SPF to send mail on the organization’s behalf.") } } else { - return []string{"Your SPF record is missing the all tag. Please visit https://dmarcguide.globalcyberalliance.org to fix this."} + advice = append(advice, "Your SPF record is missing the all tag. Please visit https://dmarcguide.globalcyberalliance.org to fix this.") + } + + if len(advice) == 0 { + advice = append(advice, "SPF seems to be setup correctly! No further action needed.") } - return []string{"SPF seems to be setup correctly! No further action needed."} + return advice } -func (a *Advisor) checkHostTLS(hostname string, port int) (advice []string) { - // strip the trailing dot from DNS records +func (s *Scanner) checkHostTLS(hostname string, port int) (advice []string) { + // Strip the trailing dot from DNS records. if string(hostname[len(hostname)-1]) == "." { hostname = hostname[:len(hostname)-1] } - // check if the advice is already in the cache - tlsAdvice := a.tlsCacheHost.Get(hostname) + // Check if the advice is already in the cache. + tlsAdvice := s.advisor.tlsCacheHost.Get(hostname) if tlsAdvice != nil { return *tlsAdvice } - // set the advice in the cache after the function returns + // Set the advice in the cache after the function returns. defer func() { - a.tlsCacheHost.Set(hostname, &advice) + s.advisor.tlsCacheHost.Set(hostname, &advice) }() if port == 0 { port = 443 } - conn, err := tls.DialWithDialer(a.dialer, "tcp", hostname+":"+cast.ToString(port), nil) + conn, err := tls.DialWithDialer(s.advisor.dialer, "tcp", hostname+":"+cast.ToString(port), nil) if err != nil { if strings.Contains(err.Error(), "no such host") { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. advice = []string{hostname + " could not be reached"} return advice } @@ -470,7 +534,7 @@ func (a *Advisor) checkHostTLS(hostname string, port int) (advice []string) { if strings.Contains(err.Error(), "certificate is not trusted") || strings.Contains(err.Error(), "failed to verify certificate") { advice = append(advice, "No valid certificate could be found.") - conn, err = tls.DialWithDialer(a.dialer, "tcp", hostname+":"+cast.ToString(port), &tls.Config{InsecureSkipVerify: true}) + conn, err = tls.DialWithDialer(s.advisor.dialer, "tcp", hostname+":"+cast.ToString(port), &tls.Config{InsecureSkipVerify: true}) if err != nil { return advice } @@ -485,26 +549,26 @@ func (a *Advisor) checkHostTLS(hostname string, port int) (advice []string) { return advice } -func (a *Advisor) checkMailTls(hostname string) (advice []string) { - // strip the trailing dot from DNS records +func (s *Scanner) checkMailTls(hostname string) (advice []string) { + // Strip the trailing dot from DNS records. if string(hostname[len(hostname)-1]) == "." { hostname = hostname[:len(hostname)-1] } - // check if the advice is already in the cache - tlsAdvice := a.tlsCacheMail.Get(hostname) + // Check if the advice is already in the cache. + tlsAdvice := s.advisor.tlsCacheMail.Get(hostname) if tlsAdvice != nil { return *tlsAdvice } - // set the advice in the cache after the function returns + // Set the advice in the cache after the function returns. defer func() { - a.tlsCacheMail.Set(hostname, &advice) + s.advisor.tlsCacheMail.Set(hostname, &advice) }() - conn, err := a.dialer.Dial("tcp", hostname+":25") + conn, err := s.advisor.dialer.Dial("tcp", hostname+":25") if err != nil { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. if strings.Contains(err.Error(), "i/o timeout") { advice = []string{"Failed to reach domain before timeout"} } else { @@ -517,7 +581,7 @@ func (a *Advisor) checkMailTls(hostname string) (advice []string) { client, err := smtp.NewClient(conn, hostname) if err != nil { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. advice = []string{"Failed to reach domain"} return advice } @@ -531,16 +595,16 @@ func (a *Advisor) checkMailTls(hostname string) (advice []string) { if strings.Contains(err.Error(), "certificate is not trusted") || strings.Contains(err.Error(), "failed to verify certificate") { advice = append(advice, "No valid certificate could be found.") - // close the existing connection and create a new one as we can't reuse it in the same way as the checkHostTLS function + // Close the existing connection and create a new one as we can't reuse it in the same way as the checkHostTLS function. if err = conn.Close(); err != nil { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. advice = append(advice, "Failed to re-attempt connection without certificate verification") return advice } - conn, err = a.dialer.Dial("tcp", hostname+"25") + conn, err = s.advisor.dialer.Dial("tcp", hostname+"25") if err != nil { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. advice = []string{"Failed to reach domain"} return advice } @@ -548,20 +612,20 @@ func (a *Advisor) checkMailTls(hostname string) (advice []string) { client, err = smtp.NewClient(conn, hostname) if err != nil { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. advice = []string{"Failed to reach domain"} return advice } - // retry with InsecureSkipVerify + // Retry with InsecureSkipVerify. tlsConfig.InsecureSkipVerify = true if err = client.StartTLS(tlsConfig); err != nil { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. advice = append(advice, "Failed to start TLS connection") return advice } } else { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. advice = []string{"Failed to start TLS connection: " + err.Error()} return advice } @@ -574,6 +638,116 @@ func (a *Advisor) checkMailTls(hostname string) (advice []string) { return advice } +func (s *Scanner) CheckMTASTS(record string, policy string) (advice []string) { + if record == "" { + return []string{"You do not have MTA-STS setup!"} + } + + if !strings.HasPrefix(record, "v=STSv1") { + advice = append(advice, "The beginning of your MTA-STS record should be v=STSv1 with specific capitalization.") + } + + if !strings.Contains(record, "id=") { + advice = append(advice, "The MTA-STS record should contain an 'id' tag.") + } + + if policy == "" { + advice = append(advice, "The MTA-STS policy is missing.") + return advice + } + lines := strings.Split(policy, "\n") + requiredFields := []string{"version:", "mode:", "mx:", "max_age:"} + for _, field := range requiredFields { + found := false + for _, line := range lines { + if strings.HasPrefix(line, field) { + found = true + if field == "mode:" { + value, _ := strings.CutPrefix(line, field) + value = strings.TrimSpace(value) + switch value { + case "enforce": + break + case "testing": + advice = append(advice, "The MTA-STS policy is in testing mode. This means that the policy will not be enforced.") + case "none": + advice = append(advice, "The MTA-STS policy is in none mode. This means that the policy will not be used.") + default: + advice = append(advice, "The MTA-STS policy mode is invalid. It should be either enforce, testing or none.") + } + } + } + } + if !found { + advice = append(advice, "The MTA-STS policy is missing the "+field+" field.") + } + } + + if len(advice) == 0 { + return []string{"MTA-STS seems to be setup correctly! No further action needed."} + } + return advice +} + +func (s *Scanner) checkSPFLookup(spf string, lookupParents []string, lookupCount *int) string { + // Get DNS lookups from record. + parts := strings.Split(spf, " ") + for _, part := range parts { + var keyValue []string + + if strings.Contains(part, ":") { + keyValue = strings.Split(part, ":") + } else { + keyValue = strings.Split(part, "=") + } + + key := strings.ToLower(keyValue[0]) + + switch key { + case "a", + "mx", + "ptr", + "exists", + "redirect": + *lookupCount++ + + case "include": + *lookupCount++ + + value := keyValue[1] + for _, parent := range lookupParents { + if parent == value { + return "SPF record contains cyclid lookup chain begining at" + key + "." + } + } + + // get spf record of target + // txtRecords, err := net.LookupTXT(value) + newSPF, err := s.dnsClient.GetTypeSPF(value) + if err != nil { + return "Error when accessing SPF record for " + value + "." + } + if spf == "" { + return "Could not find required SPF record at " + value + "." + } + + // var newSPF string + // for index, record := range txtRecords { + // if strings.HasPrefix(record, "v=spf1") { + // newSPF = txtRecords[index] + // break + // } + // } + + lookupError := s.checkSPFLookup(newSPF, append(lookupParents, value), lookupCount) + if lookupError != "" { + return lookupError + } + } + } + return "" +} + func checkTLSVersion(tlsVersion uint16) string { switch tlsVersion { case tls.VersionTLS10: diff --git a/pkg/advisor/domains.go b/pkg/scanner/domains.go similarity index 99% rename from pkg/advisor/domains.go rename to pkg/scanner/domains.go index a0b1246..bcfc1ee 100644 --- a/pkg/advisor/domains.go +++ b/pkg/scanner/domains.go @@ -1,4 +1,4 @@ -package advisor +package scanner var consumerDomainList = []string{ "126.com", diff --git a/pkg/scanner/options.go b/pkg/scanner/options.go index 2f59578..d4e110b 100644 --- a/pkg/scanner/options.go +++ b/pkg/scanner/options.go @@ -3,14 +3,12 @@ package scanner import ( "errors" "fmt" - "net" - "net/netip" "regexp" "runtime" "strings" "time" - "github.com/miekg/dns" + "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/dns" ) // OverwriteOption allows the caller to overwrite an existing option. @@ -47,6 +45,16 @@ func WithConcurrentScans(quota uint16) Option { } } +func WithCheckTLS(checkTLS bool) Option { + return func(s *Scanner) error { + if s.advisor == nil { + return errors.New("advisor not initialized") + } + s.advisor.checkTLS = checkTLS + return nil + } +} + // WithDKIMSelectors allows the caller to specify which DKIM selectors to // scan for (falling back to the default selectors if none are provided). func WithDKIMSelectors(selectors ...string) Option { @@ -62,7 +70,7 @@ func WithDKIMSelectors(selectors ...string) Option { } } - s.dkimSelectors = selectors + s.dnsClient.DKIMSelectors = selectors return nil } @@ -75,7 +83,7 @@ func WithDNSBuffer(bufferSize uint16) Option { return fmt.Errorf("invalid DNS buffer size: %d", bufferSize) } - s.dnsBuffer = bufferSize + s.dnsClient.Buffer = bufferSize return nil } @@ -88,7 +96,7 @@ func WithDNSProtocol(protocol string) Option { switch protocol { case "udp", "tcp", "tcp-tls": - s.dnsClient.Net = protocol + s.dnsClient.Protocol = protocol default: return fmt.Errorf("invalid DNS protocol: %s, valid options: udp, tcp, tcp-tls", protocol) } @@ -102,57 +110,12 @@ func WithDNSProtocol(protocol string) Option { // the nameservers specified in /etc/resolv.conf. func WithNameservers(nameservers []string) Option { return func(s *Scanner) error { - // If the provided slice of nameservers is nil, or has zero - // elements, load up /etc/resolv.conf, and get the "index" - // directives from there. - if len(nameservers) == 0 { - // check if /etc/resolv.conf exists - config, err := dns.ClientConfigFromFile("/etc/resolv.conf") - if err != nil { - // if /etc/resolv.conf does not exist, use Google and Cloudflare - s.nameservers = []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53"} - return nil - } - - nameservers = config.Servers - } - - // Make sure each of the nameservers is in the "host:port" format. - // - // The "dns" package requires that you explicitly state the port - // number for the resolvers that get queried. - for index := range nameservers { - addr, err := netip.ParseAddr(nameservers[index]) - if err != nil { - // might contain a port - host, port, err := net.SplitHostPort(nameservers[index]) - if err != nil { - return fmt.Errorf("invalid IP address: %s", nameservers[index]) - } - - // validate IP - addr, err = netip.ParseAddr(host) - if err != nil { - return fmt.Errorf("invalid IP address: %s", nameservers[index]) - } - - if addr.Is6() { - nameservers[index] = fmt.Sprintf("[%s]:%v", addr.String(), port) - } else { - nameservers[index] = fmt.Sprintf("%s:%v", addr.String(), port) - } - - continue - } - - if addr.Is6() { - nameservers[index] = fmt.Sprintf("[%s]:53", addr.String()) - } else { - nameservers[index] = fmt.Sprintf("%s:53", addr.String()) - } + nameservers, err := dns.ParseNameservers(nameservers) + if err != nil { + return fmt.Errorf("failed to parse nameservers: %w", err) } - s.nameservers = nameservers + s.dnsClient.Nameservers = nameservers return nil } diff --git a/pkg/scanner/options_test.go b/pkg/scanner/options_test.go index 0da884f..90ac844 100644 --- a/pkg/scanner/options_test.go +++ b/pkg/scanner/options_test.go @@ -71,7 +71,7 @@ func TestOptionWithDKIMSelectors(t *testing.T) { t.Run("ValidDKIMSelectors", func(t *testing.T) { scanner, err := New(logger, timeout, WithDKIMSelectors("selector1", "selector1._google")) require.NoError(t, err) - require.Equal(t, []string{"selector1", "selector1._google"}, scanner.dkimSelectors) + require.Equal(t, []string{"selector1", "selector1._google"}, scanner.dnsClient.DKIMSelectors) }) t.Run("InvalidDKIMSelectorEndingCharacter", func(t *testing.T) { @@ -112,19 +112,19 @@ func TestOptionWithDNSBuffer(t *testing.T) { t.Run("BufferWithinLimit", func(t *testing.T) { scanner, err := New(logger, timeout, WithDNSBuffer(2048)) require.NoError(t, err) - require.Equal(t, uint16(2048), scanner.dnsBuffer) + require.Equal(t, uint16(2048), scanner.dnsClient.Buffer) }) t.Run("BufferExceedsLimit", func(t *testing.T) { scanner, err := New(logger, timeout, WithDNSBuffer(5000)) require.NoError(t, err) - require.Equal(t, uint16(5000), scanner.dnsBuffer) + require.Equal(t, uint16(5000), scanner.dnsClient.Buffer) }) t.Run("BufferAtLimit", func(t *testing.T) { scanner, err := New(logger, timeout, WithDNSBuffer(4096)) require.NoError(t, err) - require.Equal(t, uint16(4096), scanner.dnsBuffer) + require.Equal(t, uint16(4096), scanner.dnsClient.Buffer) }) } @@ -140,19 +140,19 @@ func TestOptionWithDNSProtocol(t *testing.T) { t.Run("ValidProtocolTCP", func(t *testing.T) { scanner, err := New(logger, timeout, WithDNSProtocol("TCP")) require.NoError(t, err) - require.Equal(t, "tcp", scanner.dnsClient.Net) + require.Equal(t, "tcp", scanner.dnsClient.Protocol) }) t.Run("ValidProtocolTCPWithTLS", func(t *testing.T) { scanner, err := New(logger, timeout, WithDNSProtocol("TCP-tls")) require.NoError(t, err) - require.Equal(t, "tcp-tls", scanner.dnsClient.Net) + require.Equal(t, "tcp-tls", scanner.dnsClient.Protocol) }) t.Run("ValidProtocolUDP", func(t *testing.T) { scanner, err := New(logger, timeout, WithDNSProtocol("UDP")) require.NoError(t, err) - require.Equal(t, "udp", scanner.dnsClient.Net) + require.Equal(t, "udp", scanner.dnsClient.Protocol) }) } @@ -163,7 +163,7 @@ func TestOptionWithNameservers(t *testing.T) { t.Run("EmptyNameservers", func(t *testing.T) { scanner, err := New(logger, timeout, WithNameservers(nil)) require.NoError(t, err) - require.NotEmpty(t, scanner.nameservers) + require.NotEmpty(t, scanner.dnsClient.Nameservers) }) t.Run("InvalidNameservers", func(t *testing.T) { @@ -174,24 +174,24 @@ func TestOptionWithNameservers(t *testing.T) { t.Run("ValidNameserverWithPort", func(t *testing.T) { scanner, err := New(logger, timeout, WithNameservers([]string{"8.8.8.8:53"})) require.NoError(t, err) - require.Equal(t, []string{"8.8.8.8:53"}, scanner.nameservers) + require.Equal(t, []string{"8.8.8.8:53"}, scanner.dnsClient.Nameservers) }) t.Run("ValidNameserverWithoutPort", func(t *testing.T) { scanner, err := New(logger, timeout, WithNameservers([]string{"8.8.8.8"})) require.NoError(t, err) - require.Equal(t, []string{"8.8.8.8:53"}, scanner.nameservers) + require.Equal(t, []string{"8.8.8.8:53"}, scanner.dnsClient.Nameservers) }) t.Run("ValidNameserverWithPortV6", func(t *testing.T) { scanner, err := New(logger, timeout, WithNameservers([]string{"[2001:4860:4860::8888]:53"})) require.NoError(t, err) - require.Equal(t, []string{"[2001:4860:4860::8888]:53"}, scanner.nameservers) + require.Equal(t, []string{"[2001:4860:4860::8888]:53"}, scanner.dnsClient.Nameservers) }) t.Run("ValidNameserverWithoutPortV6", func(t *testing.T) { scanner, err := New(logger, timeout, WithNameservers([]string{"2001:4860:4860::8888"})) require.NoError(t, err) - require.Equal(t, []string{"[2001:4860:4860::8888]:53"}, scanner.nameservers) + require.Equal(t, []string{"[2001:4860:4860::8888]:53"}, scanner.dnsClient.Nameservers) }) } diff --git a/pkg/scanner/requests.go b/pkg/scanner/requests.go deleted file mode 100644 index 4c1a3ca..0000000 --- a/pkg/scanner/requests.go +++ /dev/null @@ -1,208 +0,0 @@ -package scanner - -import ( - "fmt" - "regexp" - "strings" - - "github.com/miekg/dns" -) - -const ( - DefaultBIMIPrefix = "v=BIMI1;" - DefaultDKIMPrefix = "v=DKIM1;" -) - -var ( - BIMIPrefix = DefaultBIMIPrefix - DKIMPrefix = DefaultDKIMPrefix - DMARCPrefix = regexp.MustCompile(`^v\s*=\s*DMARC1`) // Matches v=DMARC1 with whitespace (RFC7489). - SPFPrefix = regexp.MustCompile(`^v=(?i)spf1`) - - // knownDkimSelectors is a list of known DKIM selectors. - knownDkimSelectors = []string{ - "x", // Generic - "google", // Google - "selector1", // Microsoft - "selector2", // Microsoft - "s1", // Generic - "s2", // Generic - "k1", // MailChimp - "mandrill", // Mandrill - "everlytickey1", // Everlytic - "everlytickey2", // Everlytic - "dkim", // Hetzner - "mxvault", // MxVault - } -) - -// getDNSRecords queries the DNS server for records of a specific type for a domain. -// It returns a slice of strings (the records) and an error if any occurred. -func (s *Scanner) getDNSRecords(domain string, recordType uint16) (records []string, err error) { - answers, err := s.getDNSAnswers(domain, recordType) - if err != nil { - return nil, err - } - - for _, answer := range answers { - if answer.Header().Rrtype == dns.TypeCNAME { - if t, ok := answer.(*dns.CNAME); ok { - recursiveLookupTxt, err := s.getDNSRecords(t.Target, recordType) - if err != nil { - return nil, fmt.Errorf("failed to recursively lookup txt record for %v: %w", t.Target, err) - } - - records = append(records, recursiveLookupTxt...) - - continue - } - - answer.Header().Rrtype = recordType - } - - switch dnsRec := answer.(type) { - case *dns.A: - records = append(records, dnsRec.A.String()) - case *dns.AAAA: - records = append(records, dnsRec.AAAA.String()) - case *dns.MX: - records = append(records, dnsRec.Mx) - case *dns.NS: - records = append(records, dnsRec.Ns) - case *dns.TXT: - records = append(records, dnsRec.Txt...) - } - } - - return records, nil -} - -// getDNSAnswers queries the DNS server for answers to a specific question. -// It returns a slice of dns.RR (DNS resource records) and an error if any occurred. -func (s *Scanner) getDNSAnswers(domain string, recordType uint16) ([]dns.RR, error) { - req := &dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.SetEdns0(s.dnsBuffer, true) // increases the response buffer size - req.SetQuestion(dns.Fqdn(domain), recordType) - - in, _, err := s.dnsClient.Exchange(req, s.getNS()) - if err != nil { - return nil, err - } - - if in.Rcode != dns.RcodeSuccess { - // disregard NXDOMAIN errors - if in.Rcode == dns.RcodeNameError { - return nil, nil - } - - return nil, fmt.Errorf("DNS query failed with rcode %v", in.Rcode) - } - - if in.MsgHdr.Truncated && s.dnsBuffer < 4096 { - s.logger.Warn().Msg(fmt.Sprintf("DNS buffer %v was too small for %v, retrying with larger buffer (4096)", s.dnsBuffer, domain)) - - req.SetEdns0(4096, true) - - in, _, err = s.dnsClient.Exchange(req, s.getNS()) - if err != nil { - return nil, err - } - } - - return in.Answer, nil -} - -func (s *Scanner) getTypeBIMI(domain string) (string, error) { - for _, dname := range []string{ - "default._bimi." + domain, - domain, - } { - records, err := s.getDNSRecords(dname, dns.TypeTXT) - if err != nil { - return "", err - } - - for index, record := range records { - if strings.HasPrefix(record, BIMIPrefix) { - // TXT records can be split across multiple strings, so we need to join them - return strings.Join(records[index:], ""), nil - } - } - } - - return "", nil -} - -// getTypeDKIM queries the DNS server for DKIM records of a domain. -// It returns a string (DKIM record) and an error if any occurred. -func (s *Scanner) getTypeDKIM(domain string) (string, error) { - selectors := append(s.dkimSelectors, knownDkimSelectors...) - - for _, selector := range selectors { - records, err := s.getDNSRecords(selector+"._domainkey."+domain, dns.TypeTXT) - if err != nil { - return "", err - } - - for index, record := range records { - if strings.HasPrefix(record, DKIMPrefix) { - // TXT records can be split across multiple strings, so we need to join them - return strings.Join(records[index:], ""), nil - } - } - } - - return "", nil -} - -// getTypeDMARC queries the DNS server for DMARC records of a domain. -// It returns a string (DMARC record) and an error if any occurred. -func (s *Scanner) getTypeDMARC(domain string) (string, error) { - for _, dname := range []string{ - "_dmarc." + domain, - domain, - } { - records, err := s.getDNSRecords(dname, dns.TypeTXT) - if err != nil { - return "", err - } - - for index, record := range records { - if DMARCPrefix.Match([]byte(record)) { - // TXT records can be split across multiple strings, so we need to join them - return strings.Join(records[index:], ""), nil - } - } - } - - return "", nil -} - -// getTypeSPF queries the DNS server for SPF records of a domain. -// It returns a string (SPF record) and an error if any occurred. -func (s *Scanner) getTypeSPF(domain string) (string, error) { - records, err := s.getDNSRecords(domain, dns.TypeTXT) - if err != nil { - return "", err - } - - for _, record := range records { - if SPFPrefix.Match([]byte(record)) { - if !strings.Contains(record, "redirect=") { - return record, nil - } - - parts := strings.Fields(record) - for _, part := range parts { - if strings.Contains(part, "redirect=") { - redirectDomain := strings.TrimPrefix(part, "redirect=") - return s.getTypeSPF(redirectDomain) - } - } - } - } - - return "", nil -} diff --git a/pkg/scanner/scanner.go b/pkg/scanner/scanner.go index 9948db1..d89dda9 100644 --- a/pkg/scanner/scanner.go +++ b/pkg/scanner/scanner.go @@ -6,11 +6,10 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "time" "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/cache" - "github.com/miekg/dns" + "github.com/globalcyberalliance/domain-security-scanner/v3/pkg/dns" "github.com/panjf2000/ants/v2" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -29,32 +28,22 @@ type ( // cacheDuration is the time-to-live for cache entries. cacheDuration time.Duration - // dkimSelectors is used to specify where a DKIM record is hosted for a specific domain. - dkimSelectors []string - // DNS client shared by all goroutines the scanner spawns. dnsClient *dns.Client // dnsBuffer is used to configure the size of the buffer allocated for DNS responses. dnsBuffer uint16 - // The index of the last-used nameserver, from the nameservers slice. - // - // This field is managed by atomic operations, and should only ever be referenced by the (*Scanner).getNS() - // method. - lastNameserverIndex uint32 - // logger is the logger for the scanner. logger zerolog.Logger - // nameservers is a slice of "host:port" strings of nameservers to issue queries against. - nameservers []string - // pool is the pool of workers for the scanner. pool *ants.Pool // poolSize is the size of the pool of workers for the scanner. poolSize uint16 + + advisor *Advisor } // Option defines a functional configuration type for a *Scanner. @@ -62,14 +51,17 @@ type ( // Result holds the results of scanning a domain's DNS records. Result struct { - Domain string `json:"domain" yaml:"domain,omitempty" doc:"The domain name being scanned." example:"example.com"` - Error string `json:"error,omitempty" yaml:"error,omitempty" doc:"An error message if the scan failed." example:"invalid domain name"` - BIMI string `json:"bimi,omitempty" yaml:"bimi,omitempty" doc:"The BIMI record for the domain." example:"https://example.com/bimi.svg"` - DKIM string `json:"dkim,omitempty" yaml:"dkim,omitempty" doc:"The DKIM record for the domain." example:"v=DKIM1; k=rsa; p=MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA"` - DMARC string `json:"dmarc,omitempty" yaml:"dmarc,omitempty" doc:"The DMARC record for the domain." example:"v=DMARC1; p=none"` - MX []string `json:"mx,omitempty" yaml:"mx,omitempty" doc:"The MX records for the domain." example:"aspmx.l.google.com"` - NS []string `json:"ns,omitempty" yaml:"ns,omitempty" doc:"The NS records for the domain." example:"ns1.example.com"` - SPF string `json:"spf,omitempty" yaml:"spf,omitempty" doc:"The SPF record for the domain." example:"v=spf1 include:_spf.google.com ~all"` + Domain string `json:"domain" yaml:"domain,omitempty" doc:"The domain name being scanned." example:"example.com"` + Error string `json:"error,omitempty" yaml:"error,omitempty" doc:"An error message if the scan failed." example:"invalid domain name"` + BIMI string `json:"bimi,omitempty" yaml:"bimi,omitempty" doc:"The BIMI record for the domain." example:"https://example.com/bimi.svg"` + DKIM string `json:"dkim,omitempty" yaml:"dkim,omitempty" doc:"The DKIM record for the domain." example:"v=DKIM1; k=rsa; p=MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA"` + DMARC string `json:"dmarc,omitempty" yaml:"dmarc,omitempty" doc:"The DMARC record for the domain." example:"v=DMARC1; p=none"` + DNSSEC string `json:"dnssec,omitempty" yaml:"dnssec,omitempty" doc:"The DNSSEC records for the domain." example:""` // TODO: add example + MX []string `json:"mx,omitempty" yaml:"mx,omitempty" doc:"The MX records for the domain." example:"aspmx.l.google.com"` + NS []string `json:"ns,omitempty" yaml:"ns,omitempty" doc:"The NS records for the domain." example:"ns1.example.com"` + SPF string `json:"spf,omitempty" yaml:"spf,omitempty" doc:"The SPF record for the domain." example:"v=spf1 include:_spf.google.com ~all"` + STS string `json:"mta-sts,omitempty" yaml:"mta-sts,omitempty" doc:"The MTA-STS record for the domain." example:"v=STSv1; id=20210803T010200;"` + STSPolicy string `json:"mta-sts-policy,omitempty" yaml:"mta-sts-policy,omitempty" doc:"The MTA-STS policy for the domain." example:"version: STSv1\nmode: enforce\nmx: mail.example.com\nmx: *.example.net\nmax_age: 86400\n"` } ) @@ -78,28 +70,29 @@ func New(logger zerolog.Logger, timeout time.Duration, opts ...Option) (*Scanner return nil, errors.New("timeout must be greater than 0") } - dnsClient := new(dns.Client) - dnsClient.Net = "udp" - dnsClient.Timeout = timeout + dnsClient, err := dns.New(timeout, 4096, "") + if err != nil { + return nil, fmt.Errorf("failed to create DNS client: %w", err) + } scanner := &Scanner{ - dnsClient: dnsClient, - dnsBuffer: 4096, - logger: logger, - nameservers: []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53"}, // Set the default nameservers to Google and Cloudflare - poolSize: uint16(runtime.NumCPU()), + dnsClient: dnsClient, + dnsBuffer: 4096, + logger: logger, + poolSize: uint16(runtime.NumCPU()), } + scanner.advisor = NewAdvisor(timeout, scanner.cacheDuration) for _, opt := range opts { if err := opt(scanner); err != nil { return nil, errors.Wrap(err, "apply option") } } - // Initialize cache + // Initialize cache. scanner.cache = cache.New[Result](scanner.cacheDuration) - // Create a new pool of workers for the scanner + // Create a new pool of workers for the scanner. pool, err := ants.NewPool(int(scanner.poolSize), ants.WithExpiryDuration(timeout), ants.WithPanicHandler(func(err interface{}) { scanner.logger.Error().Err(errors.New(cast.ToString(err))).Msg("unrecoverable panic occurred while analysing pcap") })) @@ -162,13 +155,13 @@ func (s *Scanner) Scan(domains ...string) ([]*Result, error) { }() } - // check that the domain name is valid - result.NS, err = s.getDNSRecords(domainToScan, dns.TypeNS) + // Check that the domain name is valid. + result.NS, err = s.dnsClient.GetTypeNS(domainToScan) if err != nil || len(result.NS) == 0 { - // check if TXT records exist, as the nameserver check won't work for subdomains - records, err := s.getDNSAnswers(domainToScan, dns.TypeTXT) + // Check if TXT records exist, as the nameserver check won't work for subdomains. + records, err := s.dnsClient.GetDNSAnswers(domainToScan, dns.TypeTXT) if err != nil || len(records) == 0 { - // fill variable to satisfy deferred cache fill + // Fill variable to satisfy deferred cache fill. result = &Result{ Domain: domainToScan, Error: ErrInvalidDomain, @@ -184,21 +177,21 @@ func (s *Scanner) Scan(domains ...string) ([]*Result, error) { var errs []string scanWg := sync.WaitGroup{} - scanWg.Add(5) + scanWg.Add(7) - // Get BIMI record + // Get BIMI record. go func() { defer scanWg.Done() - result.BIMI, err = s.getTypeBIMI(domainToScan) + result.BIMI, err = s.dnsClient.GetTypeBIMI(domainToScan) if err != nil { errs = append(errs, "bimi:"+err.Error()) } }() - // Get DKIM record + // Get DKIM record. go func() { defer scanWg.Done() - result.DKIM, err = s.getTypeDKIM(domainToScan) + result.DKIM, err = s.dnsClient.GetTypeDKIM(domainToScan) if err != nil { errs = append(errs, "dkim:"+err.Error()) } @@ -207,25 +200,43 @@ func (s *Scanner) Scan(domains ...string) ([]*Result, error) { // Get DMARC record go func() { defer scanWg.Done() - result.DMARC, err = s.getTypeDMARC(domainToScan) + result.DMARC, err = s.dnsClient.GetTypeDMARC(domainToScan) if err != nil { errs = append(errs, "dmarc:"+err.Error()) } }() - // Get MX records + // Get DNSSEC records. + go func() { + defer scanWg.Done() + result.DNSSEC, err = s.dnsClient.GetTypeDNSSEC(domainToScan) + if err != nil { + errs = append(errs, "dnssec:"+err.Error()) + } + }() + + // Get MTA-STS record. + go func() { + defer scanWg.Done() + result.STS, result.STSPolicy, err = s.dnsClient.GetTypeMTASTS(domainToScan) + if err != nil { + errs = append(errs, "mta-sts:"+err.Error()) + } + }() + + // Get MX records. go func() { defer scanWg.Done() - result.MX, err = s.getDNSRecords(domainToScan, dns.TypeMX) + result.MX, err = s.dnsClient.GetTypeMX(domainToScan) if err != nil { errs = append(errs, "mx:"+err.Error()) } }() - // Get SPF record + // Get SPF record. go func() { defer scanWg.Done() - result.SPF, err = s.getTypeSPF(domainToScan) + result.SPF, err = s.dnsClient.GetTypeSPF(domainToScan) if err != nil { errs = append(errs, "spf:"+err.Error()) } @@ -277,13 +288,9 @@ func (s *Scanner) ScanZone(zone io.Reader) ([]*Result, error) { return s.Scan(domains...) } -// Close closes the scanner +// Close closes the scanner. func (s *Scanner) Close() { s.pool.Release() s.cache.Flush() s.logger.Debug().Msg("scanner closed") } - -func (s *Scanner) getNS() string { - return s.nameservers[int(atomic.AddUint32(&s.lastNameserverIndex, 1))%len(s.nameservers)] -}