diff --git a/core/socks.go b/core/socks.go index 39b86d6..730b753 100644 --- a/core/socks.go +++ b/core/socks.go @@ -9,6 +9,8 @@ import ( "os" "strconv" "strings" + "sync" + "time" "github.com/mythologyli/zju-connect/core/config" @@ -20,10 +22,31 @@ import ( ) type ZJUDnsResolve struct { - remoteResolver *net.Resolver + remoteUDPResolver *net.Resolver + remoteTCPResolver *net.Resolver + timer *time.Timer + useTCP bool + lock sync.RWMutex } -func (resolve ZJUDnsResolve) Resolve(ctx context.Context, host string) (context.Context, net.IP, error) { +func (resolve *ZJUDnsResolve) ResolveWithLocal(ctx context.Context, host string) (context.Context, net.IP, error) { + if target, err := net.ResolveIPAddr("ip4", host); err != nil { + log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Try IPv6 addr.") + + if target, err = net.ResolveIPAddr("ip6", host); err != nil { + log.Printf("Resolve IPv6 addr failed using local DNS: " + host + ". Reject connection.") + return ctx, nil, err + } else { + log.Printf("%s -> %s", host, target.IP.String()) + return ctx, target.IP, nil + } + } else { + log.Printf("%s -> %s", host, target.IP.String()) + return ctx, target.IP, nil + } +} + +func (resolve *ZJUDnsResolve) Resolve(ctx context.Context, host string) (context.Context, net.IP, error) { if config.IsDnsRuleAvailable() { if ip, hasDnsRule := config.GetSingleDnsRule(host); hasDnsRule { ctx = context.WithValue(ctx, "USE_PROXY", true) @@ -50,43 +73,53 @@ func (resolve ZJUDnsResolve) Resolve(ctx context.Context, host string) (context. log.Printf("%s -> %s", host, cachedIP.String()) return ctx, cachedIP, nil } else { - targets, err := resolve.remoteResolver.LookupIP(context.Background(), "ip4", host) - if err != nil { - log.Printf("Resolve IPv4 addr failed using ZJU DNS: " + host + ", using local DNS instead.") + resolve.lock.RLock() + useTCP := resolve.useTCP + resolve.lock.RUnlock() - target, err := net.ResolveIPAddr("ip4", host) + if !useTCP { + targets, err := resolve.remoteUDPResolver.LookupIP(context.Background(), "ip4", host) if err != nil { - log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Try IPv6 addr.") - - target, err := net.ResolveIPAddr("ip6", host) - if err != nil { - log.Printf("Resolve IPv6 addr failed using local DNS: " + host + ". Reject connection.") - return ctx, nil, err + if targets, err = resolve.remoteTCPResolver.LookupIP(context.Background(), "ip4", host); err != nil { + // all zju dns failed, so we keep do nothing but use local dns + // host ipv4 and host ipv6 don't set cache + log.Printf("Resolve IPv4 addr failed using ZJU UDP/TCP DNS: " + host + ", using local DNS instead.") + return resolve.ResolveWithLocal(ctx, host) } else { - log.Printf("%s -> %s", host, target.IP.String()) - return ctx, target.IP, nil + resolve.lock.Lock() + resolve.useTCP = true + if resolve.timer == nil { + resolve.timer = time.AfterFunc(10*time.Minute, func() { + resolve.lock.Lock() + resolve.useTCP = false + resolve.timer = nil + resolve.lock.Unlock() + }) + } + resolve.lock.Unlock() } - } else { - log.Printf("%s -> %s", host, target.IP.String()) - return ctx, target.IP, nil } - } else { + // set dns cache if tcp or udp dns success //TODO: whether we need all dns records? or only 10.0.0.0/8 ? SetDnsCache(host, targets[0]) log.Printf("%s -> %s", host, targets[0].String()) return ctx, targets[0], nil + } else { + // only try tcp and local dns + if targets, err := resolve.remoteTCPResolver.LookupIP(context.Background(), "ip4", host); err != nil { + log.Printf("Resolve IPv4 addr failed using ZJU TCP DNS: " + host + ", using local DNS instead.") + return resolve.ResolveWithLocal(ctx, host) + } else { + SetDnsCache(host, targets[0]) + log.Printf("%s -> %s", host, targets[0].String()) + return ctx, targets[0], nil + } } } } else { // because of OS cache, don't need extra dns memory cache - target, err := net.ResolveIPAddr("ip4", host) - if err != nil { - log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Reject connection.") - return ctx, nil, err - } else { - return ctx, target.IP, nil - } + return resolve.ResolveWithLocal(ctx, host) } } @@ -100,7 +133,7 @@ func dialDirect(ctx context.Context, network, addr string) (net.Conn, error) { } func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer string) { - var remoteResolver = &net.Resolver{ + var remoteUDPResolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { addrDns := tcpip.FullAddress{ @@ -117,6 +150,17 @@ func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer return gonet.DialUDP(ipStack, &bind, &addrDns, header.IPv4ProtocolNumber) }, } + var remoteTCPResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + addrDns := tcpip.FullAddress{ + NIC: defaultNIC, + Port: uint16(53), + Addr: tcpip.Address(net.ParseIP(dnsServer).To4()), + } + return gonet.DialTCP(ipStack, addrDns, header.IPv4ProtocolNumber) + }, + } var authMethods []socks5.Authenticator if SocksUser != "" && SocksPasswd != "" { @@ -244,8 +288,11 @@ func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer server := socks5.NewServer( socks5.WithAuthMethods(authMethods), - socks5.WithResolver(ZJUDnsResolve{ - remoteResolver: remoteResolver, + socks5.WithResolver(&ZJUDnsResolve{ + remoteTCPResolver: remoteTCPResolver, + remoteUDPResolver: remoteUDPResolver, + useTCP: false, + timer: nil, }), socks5.WithDial(zjuDialer), socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "", log.LstdFlags))),