123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364 |
- package socks5
- import (
- "fmt"
- "io"
- "net"
- "strconv"
- "strings"
- "golang.org/x/net/context"
- )
// SOCKS5 command codes carried in the second byte of a request header.
const (
	ConnectCommand   uint8 = 1
	BindCommand      uint8 = 2
	AssociateCommand uint8 = 3
)

// Address type tags that precede an address on the wire.
const (
	ipv4Address uint8 = 1
	fqdnAddress uint8 = 3
	ipv6Address uint8 = 4
)
// Reply codes placed in the second byte of a server reply.
const (
	successReply         uint8 = 0
	serverFailure        uint8 = 1
	ruleFailure          uint8 = 2
	networkUnreachable   uint8 = 3
	hostUnreachable      uint8 = 4
	connectionRefused    uint8 = 5
	ttlExpired           uint8 = 6
	commandNotSupported  uint8 = 7
	addrTypeNotSupported uint8 = 8
)
// unrecognizedAddrType is returned by readAddrSpec when the address type tag
// is not one of the IPv4, IPv6 or FQDN values.
var unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
- type AddressRewriter interface {
- Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
- }
// AddrSpec describes a network endpoint: an optional domain name, an IP
// address and a port.
type AddrSpec struct {
	FQDN string
	IP   net.IP
	Port int
}

// String renders the address for display. When a domain name is present the
// output has the form "name (ip):port", otherwise "ip:port".
func (a *AddrSpec) String() string {
	if a.FQDN == "" {
		return fmt.Sprintf("%s:%d", a.IP, a.Port)
	}
	return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
}

// Address returns a "host:port" string suitable for dialing. The IP is used
// as the host when it is set; the domain name is used only as a fallback.
func (a AddrSpec) Address() string {
	host := a.FQDN
	if len(a.IP) > 0 {
		host = a.IP.String()
	}
	return net.JoinHostPort(host, strconv.Itoa(a.Port))
}
- type Request struct {
-
- Version uint8
-
- Command uint8
-
- AuthContext *AuthContext
-
- RemoteAddr *AddrSpec
-
- DestAddr *AddrSpec
-
- realDestAddr *AddrSpec
- bufConn io.Reader
- }
// conn is the subset of a network connection the request handlers need:
// the ability to write bytes back to the client and to report its address.
type conn interface {
	io.Writer
	RemoteAddr() net.Addr
}
- func NewRequest(bufConn io.Reader) (*Request, error) {
-
- header := []byte{0, 0, 0}
- if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
- return nil, fmt.Errorf("Failed to get command version: %v", err)
- }
-
- if header[0] != socks5Version {
- return nil, fmt.Errorf("Unsupported command version: %v", header[0])
- }
-
- dest, err := readAddrSpec(bufConn)
- if err != nil {
- return nil, err
- }
- request := &Request{
- Version: socks5Version,
- Command: header[1],
- DestAddr: dest,
- bufConn: bufConn,
- }
- return request, nil
- }
- func (s *Server) handleRequest(req *Request, conn conn) error {
- ctx := context.Background()
-
- dest := req.DestAddr
- if dest.FQDN != "" {
- ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
- if err != nil {
- if err := sendReply(conn, hostUnreachable, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
- }
- ctx = ctx_
- dest.IP = addr
- }
-
- req.realDestAddr = req.DestAddr
- if s.config.Rewriter != nil {
- ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
- }
-
- switch req.Command {
- case ConnectCommand:
- return s.handleConnect(ctx, conn, req)
- case BindCommand:
- return s.handleBind(ctx, conn, req)
- case AssociateCommand:
- return s.handleAssociate(ctx, conn, req)
- default:
- if err := sendReply(conn, commandNotSupported, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Unsupported command: %v", req.Command)
- }
- }
- func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
-
- if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
- if err := sendReply(conn, ruleFailure, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
- } else {
- ctx = ctx_
- }
-
- dial := s.config.Dial
- if dial == nil {
- dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
- return net.Dial(net_, addr)
- }
- }
- target, err := dial(ctx, "tcp", req.realDestAddr.Address())
- if err != nil {
- msg := err.Error()
- resp := hostUnreachable
- if strings.Contains(msg, "refused") {
- resp = connectionRefused
- } else if strings.Contains(msg, "network is unreachable") {
- resp = networkUnreachable
- }
- if err := sendReply(conn, resp, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
- }
- defer target.Close()
-
- local := target.LocalAddr().(*net.TCPAddr)
- bind := AddrSpec{IP: local.IP, Port: local.Port}
- if err := sendReply(conn, successReply, &bind); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
-
- errCh := make(chan error, 2)
- go proxy(target, req.bufConn, errCh)
- go proxy(conn, target, errCh)
-
- for i := 0; i < 2; i++ {
- e := <-errCh
- if e != nil {
-
- return e
- }
- }
- return nil
- }
- func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
-
- if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
- if err := sendReply(conn, ruleFailure, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
- } else {
- ctx = ctx_
- }
-
- if err := sendReply(conn, commandNotSupported, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return nil
- }
- func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
-
- if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
- if err := sendReply(conn, ruleFailure, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
- } else {
- ctx = ctx_
- }
-
- if err := sendReply(conn, commandNotSupported, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return nil
- }
- func readAddrSpec(r io.Reader) (*AddrSpec, error) {
- d := &AddrSpec{}
-
- addrType := []byte{0}
- if _, err := r.Read(addrType); err != nil {
- return nil, err
- }
-
- switch addrType[0] {
- case ipv4Address:
- addr := make([]byte, 4)
- if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
- return nil, err
- }
- d.IP = net.IP(addr)
- case ipv6Address:
- addr := make([]byte, 16)
- if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
- return nil, err
- }
- d.IP = net.IP(addr)
- case fqdnAddress:
- if _, err := r.Read(addrType); err != nil {
- return nil, err
- }
- addrLen := int(addrType[0])
- fqdn := make([]byte, addrLen)
- if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
- return nil, err
- }
- d.FQDN = string(fqdn)
- default:
- return nil, unrecognizedAddrType
- }
-
- port := []byte{0, 0}
- if _, err := io.ReadAtLeast(r, port, 2); err != nil {
- return nil, err
- }
- d.Port = (int(port[0]) << 8) | int(port[1])
- return d, nil
- }
- func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
-
- var addrType uint8
- var addrBody []byte
- var addrPort uint16
- switch {
- case addr == nil:
- addrType = ipv4Address
- addrBody = []byte{0, 0, 0, 0}
- addrPort = 0
- case addr.FQDN != "":
- addrType = fqdnAddress
- addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
- addrPort = uint16(addr.Port)
- case addr.IP.To4() != nil:
- addrType = ipv4Address
- addrBody = []byte(addr.IP.To4())
- addrPort = uint16(addr.Port)
- case addr.IP.To16() != nil:
- addrType = ipv6Address
- addrBody = []byte(addr.IP.To16())
- addrPort = uint16(addr.Port)
- default:
- return fmt.Errorf("Failed to format address: %v", addr)
- }
-
- msg := make([]byte, 6+len(addrBody))
- msg[0] = socks5Version
- msg[1] = resp
- msg[2] = 0
- msg[3] = addrType
- copy(msg[4:], addrBody)
- msg[4+len(addrBody)] = byte(addrPort >> 8)
- msg[4+len(addrBody)+1] = byte(addrPort & 0xff)
-
- _, err := w.Write(msg)
- return err
- }
// closeWriter is implemented by writers that can shut down only their write
// side.
type closeWriter interface {
	CloseWrite() error
}

// proxy copies everything from src to dst. Once the copy ends it closes the
// write side of dst when dst supports that, then delivers the copy error
// (nil on a clean finish) on errCh.
func proxy(dst io.Writer, src io.Reader, errCh chan error) {
	_, copyErr := io.Copy(dst, src)

	// Closing the write side is best-effort; its error is deliberately
	// discarded.
	if half, ok := dst.(closeWriter); ok {
		_ = half.CloseWrite()
	}

	errCh <- copyErr
}
|