package basic

import (
	"fmt"
	"strings"

	"github.com/fatedier/frp/test/e2e/framework"
	"github.com/fatedier/frp/test/e2e/framework/consts"

	. "github.com/onsi/ginkgo"
)

type generalTestConfigures struct {
	server      string
	client      string
	expectError bool
}

func defineClientServerTest(desc string, f *framework.Framework, configures *generalTestConfigures) {
	It(desc, func() {
		serverConf := consts.DefaultServerConfig
		clientConf := consts.DefaultClientConfig

		serverConf += fmt.Sprintf(`
				%s
				`, configures.server)

		clientConf += fmt.Sprintf(`
				%s

				[tcp]
				type = tcp
				local_port = {{ .%s }}
				remote_port = {{ .%s }}

				[udp]
				type = udp
				local_port = {{ .%s }}
				remote_port = {{ .%s }}
				`, configures.client,
			framework.TCPEchoServerPort, framework.GenPortName("TCP"),
			framework.UDPEchoServerPort, framework.GenPortName("UDP"),
		)

		f.RunProcesses([]string{serverConf}, []string{clientConf})

		if !configures.expectError {
			framework.ExpectTCPRequest(f.UsedPorts[framework.GenPortName("TCP")],
				[]byte(consts.TestString), []byte(consts.TestString), connTimeout, "tcp proxy")
			framework.ExpectUDPRequest(f.UsedPorts[framework.GenPortName("UDP")],
				[]byte(consts.TestString), []byte(consts.TestString), connTimeout, "udp proxy")
		} else {
			framework.ExpectTCPRequestError(f.UsedPorts[framework.GenPortName("TCP")],
				[]byte(consts.TestString), connTimeout, "tcp proxy")
			framework.ExpectUDPRequestError(f.UsedPorts[framework.GenPortName("UDP")],
				[]byte(consts.TestString), connTimeout, "udp proxy")
		}
	})
}

var _ = Describe("[Feature: Client-Server]", func() {
	f := framework.NewDefaultFramework()

	Describe("Protocol", func() {
		supportProtocols := []string{"tcp", "kcp", "websocket"}
		for _, protocol := range supportProtocols {
			configures := &generalTestConfigures{
				server: fmt.Sprintf(`
				kcp_bind_port = {{ .%s }}
				protocol = %s"
				`, consts.PortServerName, protocol),
				client: "protocol = " + protocol,
			}
			defineClientServerTest(protocol, f, configures)
		}
	})

	Describe("Authentication", func() {
		defineClientServerTest("Token Correct", f, &generalTestConfigures{
			server: "token = 123456",
			client: "token = 123456",
		})

		defineClientServerTest("Token Incorrect", f, &generalTestConfigures{
			server:      "token = 123456",
			client:      "token = invalid",
			expectError: true,
		})
	})

	Describe("TLS", func() {
		supportProtocols := []string{"tcp", "kcp", "websocket"}
		for _, protocol := range supportProtocols {
			tmp := protocol
			defineClientServerTest("TLS over "+strings.ToUpper(tmp), f, &generalTestConfigures{
				server: fmt.Sprintf(`
				kcp_bind_port = {{ .%s }}
				protocol = %s
				`, consts.PortServerName, protocol),
				client: fmt.Sprintf(`tls_enable = true
				protocol = %s
				`, protocol),
			})
		}

		defineClientServerTest("enable tls_only, client with TLS", f, &generalTestConfigures{
			server: "tls_only = true",
			client: "tls_enable = true",
		})
		defineClientServerTest("enable tls_only, client without TLS", f, &generalTestConfigures{
			server:      "tls_only = true",
			expectError: true,
		})
	})
})