Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ type TLSConfig struct {
VerifyPeerCertInNames []string `json:"verifyPeerCertInNames"`
ECHServerKeys string `json:"echServerKeys"`
ECHConfigList string `json:"echConfigList"`
ECHForceQuery bool `json:"echForceQuery"`
ECHForceQuery string `json:"echForceQuery"`
ECHSocketSettings *SocketConfig `json:"echSockopt"`
}

Expand Down Expand Up @@ -494,6 +494,12 @@ func (c *TLSConfig) Build() (proto.Message, error) {
}
config.EchServerKeys = EchPrivateKey
}
switch c.ECHForceQuery {
case "none", "half", "full", "":
config.EchForceQuery = c.ECHForceQuery
default:
return nil, errors.New(`invalid "echForceQuery": `, c.ECHForceQuery)
}
config.EchForceQuery = c.ECHForceQuery
config.EchConfigList = c.ECHConfigList
if c.ECHSocketSettings != nil {
Expand Down
3 changes: 1 addition & 2 deletions transport/internet/tls/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"github.com/xtls/xray-core/features/dns"
"os"
"slices"
"strings"
Expand Down Expand Up @@ -451,7 +450,7 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
if len(c.EchConfigList) > 0 || len(c.EchServerKeys) > 0 {
err := ApplyECH(c, config)
if err != nil {
if c.EchForceQuery || errors.Cause(err) != dns.ErrEmptyResponse {
if c.EchForceQuery == "full" {
errors.LogError(context.Background(), err)
} else {
errors.LogInfo(context.Background(), err)
Expand Down
8 changes: 4 additions & 4 deletions transport/internet/tls/config.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion transport/internet/tls/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ message Config {

string ech_config_list = 19;

bool ech_force_query = 20;
string ech_force_query = 20;

SocketConfig ech_socket_settings = 21;
}
54 changes: 36 additions & 18 deletions transport/internet/tls/ech.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ import (
"encoding/base64"
"encoding/binary"
"fmt"
utls "github.com/refraction-networking/utls"
"github.com/xtls/xray-core/common/crypto"
dns2 "github.com/xtls/xray-core/features/dns"
"golang.org/x/net/http2"
"io"
"net/http"
"net/url"
Expand All @@ -21,6 +17,11 @@ import (
"sync/atomic"
"time"

utls "github.com/refraction-networking/utls"
"github.com/xtls/xray-core/common/crypto"
dns2 "github.com/xtls/xray-core/features/dns"
"golang.org/x/net/http2"

"github.com/miekg/dns"
"github.com/xtls/reality"
"github.com/xtls/reality/hpke"
Expand Down Expand Up @@ -52,10 +53,18 @@ func ApplyECH(c *Config, config *tls.Config) error {

// for client
if len(c.EchConfigList) != 0 {
ECHForceQuery := c.EchForceQuery
switch ECHForceQuery {
case "none", "half", "full":
case "":
ECHForceQuery = "none" // default to none
default:
panic("Invalid ECHForceQuery: " + c.EchForceQuery)
}
defer func() {
// if failed to get ECHConfig, use an invalid one to make connection fail
if err != nil {
if c.EchForceQuery {
if err != nil || len(ECHConfig) == 0 {
if ECHForceQuery == "full" {
ECHConfig = []byte{1, 1, 4, 5, 1, 4}
}
}
Expand Down Expand Up @@ -106,32 +115,40 @@ type echConfigRecord struct {
}

var (
// key value must be like this: "example.com|udp://1.1.1.1"
// The keys for both maps must be generated by ECHCacheKey().
GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]()
clientForECHDOH = utils.NewTypedSyncMap[string, *http.Client]()
)

// sockopt can be nil if not specified.
// if for clientForECHDOH, domain can be empty.
func ECHCacheKey(server, domain string, sockopt *internet.SocketConfig) string {
return server + "|" + domain + "|" + fmt.Sprintf("%p", sockopt)
}

// Update updates the ECH config for given domain and server.
// this method is concurrent safe, only one update request will be sent, others get the cache.
// if isLockedUpdate is true, it will not try to acquire the lock.
func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery bool, sockopt *internet.SocketConfig) ([]byte, error) {
func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) {
if !isLockedUpdate {
c.UpdateLock.Lock()
defer c.UpdateLock.Unlock()
}
// Double check cache after acquiring lock
configRecord := c.configRecord.Load()
if configRecord.expire.After(time.Now()) {
if configRecord.expire.After(time.Now()) && configRecord.err == nil {
errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
return configRecord.config, configRecord.err
}
// Query ECH config from DNS server
errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
echConfig, ttl, err := dnsQuery(server, domain, sockopt)
if err != nil {
if forceQuery || ttl == 0 {
return nil, err
}
// if in "full", directly return
if err != nil && forceQuery == "full" {
return nil, err
}
if ttl == 0 {
ttl = dns2.DefaultTTL
}
configRecord = &echConfigRecord{
config: echConfig,
Expand All @@ -144,16 +161,16 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo

// QueryRecord returns the ECH config for given domain.
// If the record is not in cache or expired, it will query the DNS server and update the cache.
func QueryRecord(domain string, server string, forceQuery bool, sockopt *internet.SocketConfig) ([]byte, error) {
GlobalECHConfigCacheKey := domain + "|" + server + "|" + fmt.Sprintf("%p", sockopt)
func QueryRecord(domain string, server string, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) {
GlobalECHConfigCacheKey := ECHCacheKey(server, domain, sockopt)
echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey)
if !ok {
echConfigCache = &ECHConfigCache{}
echConfigCache.configRecord.Store(&echConfigRecord{})
echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache)
}
configRecord := echConfigCache.configRecord.Load()
if configRecord.expire.After(time.Now()) {
if configRecord.expire.After(time.Now()) && (configRecord.err == nil || forceQuery == "none") {
errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
return configRecord.config, configRecord.err
}
Expand Down Expand Up @@ -196,7 +213,7 @@ func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]b
return nil, 0, err
}
var client *http.Client
serverKey := server + "|" + fmt.Sprintf("%p", sockopt)
serverKey := ECHCacheKey(server, "", sockopt)
if client, _ = clientForECHDOH.Load(serverKey); client == nil {
// All traffic sent by core should via xray's internet.DialSystem
// This involves the behavior of some Android VPN GUI clients
Expand Down Expand Up @@ -307,7 +324,8 @@ func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]b
}
}
}
return nil, dns2.DefaultTTL, dns2.ErrEmptyResponse
// empty is valid, means no ECH config found
return nil, dns2.DefaultTTL, nil
}

// reference github.com/OmarTariq612/goech
Expand Down
21 changes: 5 additions & 16 deletions transport/internet/tls/ech_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tls

import (
"fmt"
"io"
"net/http"
"strings"
Expand Down Expand Up @@ -41,7 +40,7 @@ func TestECHDial(t *testing.T) {
}
wg.Wait()
// check cache
echConfigCache, ok := GlobalECHConfigCache.Load("encryptedsni.com|udp://1.1.1.1" + "|" + fmt.Sprintf("%p", config.EchSocketSettings))
echConfigCache, ok := GlobalECHConfigCache.Load(ECHCacheKey("udp://1.1.1.1", "encryptedsni.com", nil))
if !ok {
t.Error("ECH config cache not found")

Expand All @@ -60,22 +59,12 @@ func TestECHDial(t *testing.T) {
func TestECHDialFail(t *testing.T) {
config := &Config{
ServerName: "cloudflare.com",
EchConfigList: "udp://1.1.1.1",
EchConfigList: "udp://127.0.0.1",
EchForceQuery: "half",
}
TLSConfig := config.GetTLSConfig()
TLSConfig.NextProtos = []string{"http/1.1"}
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: TLSConfig,
},
}
resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
common.Must(err)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
common.Must(err)
config.GetTLSConfig()
// check cache
echConfigCache, ok := GlobalECHConfigCache.Load("cloudflare.com|udp://1.1.1.1" + "|" + fmt.Sprintf("%p", config.EchSocketSettings))
echConfigCache, ok := GlobalECHConfigCache.Load(ECHCacheKey("udp://127.0.0.1", "cloudflare.com", nil))
if !ok {
t.Error("ECH config cache not found")
}
Expand Down
Loading