https.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. // Copyright 2016 fatedier, fatedier@gmail.com
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package vhost
  15. import (
  16. "fmt"
  17. "io"
  18. "net"
  19. "strings"
  20. "time"
  21. gnet "github.com/fatedier/golib/net"
  22. "github.com/fatedier/golib/pool"
  23. )
  24. const (
  25. typeClientHello uint8 = 1 // Type client hello
  26. )
  27. // TLS extension numbers
  28. const (
  29. extensionServerName uint16 = 0
  30. extensionStatusRequest uint16 = 5
  31. extensionSupportedCurves uint16 = 10
  32. extensionSupportedPoints uint16 = 11
  33. extensionSignatureAlgorithms uint16 = 13
  34. extensionALPN uint16 = 16
  35. extensionSCT uint16 = 18
  36. extensionSessionTicket uint16 = 35
  37. extensionNextProtoNeg uint16 = 13172 // not IANA assigned
  38. extensionRenegotiationInfo uint16 = 0xff01
  39. )
  40. type HTTPSMuxer struct {
  41. *Muxer
  42. }
  43. func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
  44. mux, err := NewMuxer(listener, GetHTTPSHostname, nil, nil, nil, timeout)
  45. return &HTTPSMuxer{mux}, err
  46. }
  47. func readHandshake(rd io.Reader) (host string, err error) {
  48. data := pool.GetBuf(1024)
  49. origin := data
  50. defer pool.PutBuf(origin)
  51. _, err = io.ReadFull(rd, data[:47])
  52. if err != nil {
  53. return
  54. }
  55. length, err := rd.Read(data[47:])
  56. if err != nil {
  57. return
  58. }
  59. length += 47
  60. data = data[:length]
  61. if uint8(data[5]) != typeClientHello {
  62. err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
  63. return
  64. }
  65. // session
  66. sessionIDLen := int(data[43])
  67. if sessionIDLen > 32 || len(data) < 44+sessionIDLen {
  68. err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIDLen)
  69. return
  70. }
  71. data = data[44+sessionIDLen:]
  72. if len(data) < 2 {
  73. err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
  74. return
  75. }
  76. // cipher suite numbers
  77. cipherSuiteLen := int(data[0])<<8 | int(data[1])
  78. if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
  79. err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
  80. return
  81. }
  82. data = data[2+cipherSuiteLen:]
  83. if len(data) < 1 {
  84. err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
  85. return
  86. }
  87. // compression method
  88. compressionMethodsLen := int(data[0])
  89. if len(data) < 1+compressionMethodsLen {
  90. err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
  91. return
  92. }
  93. data = data[1+compressionMethodsLen:]
  94. if len(data) == 0 {
  95. // ClientHello is optionally followed by extension data
  96. err = fmt.Errorf("readHandshake: there is no extension data to get servername")
  97. return
  98. }
  99. if len(data) < 2 {
  100. err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short", len(data))
  101. return
  102. }
  103. extensionsLength := int(data[0])<<8 | int(data[1])
  104. data = data[2:]
  105. if extensionsLength != len(data) {
  106. err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
  107. return
  108. }
  109. for len(data) != 0 {
  110. if len(data) < 4 {
  111. err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
  112. return
  113. }
  114. extension := uint16(data[0])<<8 | uint16(data[1])
  115. length := int(data[2])<<8 | int(data[3])
  116. data = data[4:]
  117. if len(data) < length {
  118. err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
  119. return
  120. }
  121. switch extension {
  122. case extensionRenegotiationInfo:
  123. if length != 1 || data[0] != 0 {
  124. err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
  125. return
  126. }
  127. case extensionNextProtoNeg:
  128. case extensionStatusRequest:
  129. case extensionServerName:
  130. d := data[:length]
  131. if len(d) < 2 {
  132. err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
  133. return
  134. }
  135. namesLen := int(d[0])<<8 | int(d[1])
  136. d = d[2:]
  137. if len(d) != namesLen {
  138. err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
  139. return
  140. }
  141. for len(d) > 0 {
  142. if len(d) < 3 {
  143. err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
  144. return
  145. }
  146. nameType := d[0]
  147. nameLen := int(d[1])<<8 | int(d[2])
  148. d = d[3:]
  149. if len(d) < nameLen {
  150. err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
  151. return
  152. }
  153. if nameType == 0 {
  154. serverName := string(d[:nameLen])
  155. host = strings.TrimSpace(serverName)
  156. return host, nil
  157. }
  158. d = d[nameLen:]
  159. }
  160. }
  161. data = data[length:]
  162. }
  163. err = fmt.Errorf("Unknown error")
  164. return
  165. }
  166. func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
  167. reqInfoMap := make(map[string]string, 0)
  168. sc, rd := gnet.NewSharedConn(c)
  169. host, err := readHandshake(rd)
  170. if err != nil {
  171. return nil, reqInfoMap, err
  172. }
  173. reqInfoMap["Host"] = host
  174. reqInfoMap["Scheme"] = "https"
  175. return sc, reqInfoMap, nil
  176. }