|
@@ -89,13 +89,15 @@ func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy {
|
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
url := ctx.Value("url").(string)
|
|
|
host := getHostFromAddr(ctx.Value("host").(string))
|
|
|
- return rp.CreateConnection(host, url)
|
|
|
+ remote := ctx.Value("remote").(string)
|
|
|
+ return rp.CreateConnection(host, url, remote)
|
|
|
},
|
|
|
},
|
|
|
WebSocketDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
url := ctx.Value("url").(string)
|
|
|
host := getHostFromAddr(ctx.Value("host").(string))
|
|
|
- return rp.CreateConnection(host, url)
|
|
|
+ remote := ctx.Value("remote").(string)
|
|
|
+ return rp.CreateConnection(host, url, remote)
|
|
|
},
|
|
|
BufferPool: newWrapPool(),
|
|
|
ErrorLog: log.New(newWrapLogger(), "", 0),
|
|
@@ -138,12 +140,12 @@ func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (net.Conn, error) {
|
|
|
+func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) {
|
|
|
vr, ok := rp.getVhost(domain, location)
|
|
|
if ok {
|
|
|
fn := vr.payload.(*VhostRouteConfig).CreateConnFn
|
|
|
if fn != nil {
|
|
|
- return fn()
|
|
|
+ return fn(remoteAddr)
|
|
|
}
|
|
|
}
|
|
|
return nil, fmt.Errorf("%v: %s %s", ErrNoDomain, domain, location)
|