Skip to content

Commit

Permalink
Merge pull request #34 from cxz66666/use-both-dns
Browse files Browse the repository at this point in the history
feat: support UDP/TCP DNS select [WIP]
  • Loading branch information
Mythologyli committed Sep 22, 2023
2 parents 9ed7520 + 772b5bd commit e29553b
Showing 1 changed file with 75 additions and 28 deletions.
103 changes: 75 additions & 28 deletions core/socks.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"

"github.com/mythologyli/zju-connect/core/config"

Expand All @@ -20,10 +22,31 @@ import (
)

type ZJUDnsResolve struct {
remoteResolver *net.Resolver
remoteUDPResolver *net.Resolver
remoteTCPResolver *net.Resolver
timer *time.Timer
useTCP bool
lock sync.RWMutex
}

func (resolve ZJUDnsResolve) Resolve(ctx context.Context, host string) (context.Context, net.IP, error) {
func (resolve *ZJUDnsResolve) ResolveWithLocal(ctx context.Context, host string) (context.Context, net.IP, error) {
if target, err := net.ResolveIPAddr("ip4", host); err != nil {
log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Try IPv6 addr.")

if target, err = net.ResolveIPAddr("ip6", host); err != nil {
log.Printf("Resolve IPv6 addr failed using local DNS: " + host + ". Reject connection.")
return ctx, nil, err
} else {
log.Printf("%s -> %s", host, target.IP.String())
return ctx, target.IP, nil
}
} else {
log.Printf("%s -> %s", host, target.IP.String())
return ctx, target.IP, nil
}
}

func (resolve *ZJUDnsResolve) Resolve(ctx context.Context, host string) (context.Context, net.IP, error) {
if config.IsDnsRuleAvailable() {
if ip, hasDnsRule := config.GetSingleDnsRule(host); hasDnsRule {
ctx = context.WithValue(ctx, "USE_PROXY", true)
Expand All @@ -50,43 +73,53 @@ func (resolve ZJUDnsResolve) Resolve(ctx context.Context, host string) (context.
log.Printf("%s -> %s", host, cachedIP.String())
return ctx, cachedIP, nil
} else {
targets, err := resolve.remoteResolver.LookupIP(context.Background(), "ip4", host)
if err != nil {
log.Printf("Resolve IPv4 addr failed using ZJU DNS: " + host + ", using local DNS instead.")
resolve.lock.RLock()
useTCP := resolve.useTCP
resolve.lock.RUnlock()

target, err := net.ResolveIPAddr("ip4", host)
if !useTCP {
targets, err := resolve.remoteUDPResolver.LookupIP(context.Background(), "ip4", host)
if err != nil {
log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Try IPv6 addr.")

target, err := net.ResolveIPAddr("ip6", host)
if err != nil {
log.Printf("Resolve IPv6 addr failed using local DNS: " + host + ". Reject connection.")
return ctx, nil, err
if targets, err = resolve.remoteTCPResolver.LookupIP(context.Background(), "ip4", host); err != nil {
// all zju dns failed, so we keep do nothing but use local dns
// host ipv4 and host ipv6 don't set cache
log.Printf("Resolve IPv4 addr failed using ZJU UDP/TCP DNS: " + host + ", using local DNS instead.")
return resolve.ResolveWithLocal(ctx, host)
} else {
log.Printf("%s -> %s", host, target.IP.String())
return ctx, target.IP, nil
resolve.lock.Lock()
resolve.useTCP = true
if resolve.timer == nil {
resolve.timer = time.AfterFunc(10*time.Minute, func() {
resolve.lock.Lock()
resolve.useTCP = false
resolve.timer = nil
resolve.lock.Unlock()
})
}
resolve.lock.Unlock()
}
} else {
log.Printf("%s -> %s", host, target.IP.String())
return ctx, target.IP, nil
}
} else {
// set dns cache if tcp or udp dns success
//TODO: whether we need all dns records? or only 10.0.0.0/8 ?
SetDnsCache(host, targets[0])
log.Printf("%s -> %s", host, targets[0].String())
return ctx, targets[0], nil
} else {
// only try tcp and local dns
if targets, err := resolve.remoteTCPResolver.LookupIP(context.Background(), "ip4", host); err != nil {
log.Printf("Resolve IPv4 addr failed using ZJU TCP DNS: " + host + ", using local DNS instead.")
return resolve.ResolveWithLocal(ctx, host)
} else {
SetDnsCache(host, targets[0])
log.Printf("%s -> %s", host, targets[0].String())
return ctx, targets[0], nil
}
}
}

} else {
// because of OS cache, don't need extra dns memory cache
target, err := net.ResolveIPAddr("ip4", host)
if err != nil {
log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Reject connection.")
return ctx, nil, err
} else {
return ctx, target.IP, nil
}
return resolve.ResolveWithLocal(ctx, host)
}
}

Expand All @@ -100,7 +133,7 @@ func dialDirect(ctx context.Context, network, addr string) (net.Conn, error) {
}

func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer string) {
var remoteResolver = &net.Resolver{
var remoteUDPResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
addrDns := tcpip.FullAddress{
Expand All @@ -117,6 +150,17 @@ func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer
return gonet.DialUDP(ipStack, &bind, &addrDns, header.IPv4ProtocolNumber)
},
}
var remoteTCPResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
addrDns := tcpip.FullAddress{
NIC: defaultNIC,
Port: uint16(53),
Addr: tcpip.Address(net.ParseIP(dnsServer).To4()),
}
return gonet.DialTCP(ipStack, addrDns, header.IPv4ProtocolNumber)
},
}

var authMethods []socks5.Authenticator
if SocksUser != "" && SocksPasswd != "" {
Expand Down Expand Up @@ -244,8 +288,11 @@ func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer

server := socks5.NewServer(
socks5.WithAuthMethods(authMethods),
socks5.WithResolver(ZJUDnsResolve{
remoteResolver: remoteResolver,
socks5.WithResolver(&ZJUDnsResolve{
remoteTCPResolver: remoteTCPResolver,
remoteUDPResolver: remoteUDPResolver,
useTCP: false,
timer: nil,
}),
socks5.WithDial(zjuDialer),
socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "", log.LstdFlags))),
Expand Down

0 comments on commit e29553b

Please sign in to comment.