// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package orm

import (
	"fmt"
	"strings"
	"time"
)

// dbTable describes one joined table tracked by a dbTables collection.
//
// NOTE: set and add build values of this type with positional composite
// literals, so the order of the fields below must not change.
type dbTable struct {
	// 1-based sequence number: len(dbTables.tables)+1 at creation time.
	id    int
	// SQL alias of the joined table, formatted as "T<id>".
	index string
	// names joined with ExprSep; also the key in dbTables.tablesM.
	name  string
	// relation field names forming the path from the base model to this table.
	names []string
	// set by parseRelated for paths that were requested explicitly or reached
	// with depth > 0 (and contain no reverse field).
	// NOTE(review): presumably tells the query builder to select this
	// table's columns — confirm against the readers of this flag.
	sel   bool
	// true renders "INNER JOIN", false renders "LEFT OUTER JOIN" (see getJoinSQL).
	inner bool
	// model info of the joined table; getJoinSQL reads its table name from here.
	mi    *modelInfo
	// relation field the join was created for.
	fi    *fieldInfo
	// table joined immediately before this one on the same path; nil means
	// the join is made against the base table alias "T0".
	jtl   *dbTable
}

// dbTables is the collection of joined tables built up while a query's
// related fields, conditions, grouping and ordering are parsed.
type dbTables struct {
	// joined tables keyed by their ExprSep-joined path name.
	tablesM map[string]*dbTable
	// the same tables in creation order; getJoinSQL emits joins in this order.
	tables  []*dbTable
	// model info of the base model the query starts from (aliased "T0").
	mi      *modelInfo
	// backend used for table quoting, operator SQL generation and MaxLimit.
	base    dbBaser
	// when true, every join created is non-inner (LEFT OUTER JOIN) and
	// parseExprs only registers a join for a relation if the following path
	// element also resolves, stopping early when that element is a primary key.
	skipEnd bool
}

// set stores the table reached through the relation path names.
// When the path is already registered, its name, model info, field info and
// inner flag are overwritten in place; otherwise a new entry is appended and
// given the next "T<n>" alias. The stored entry is returned in both cases.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
	key := strings.Join(names, ExprSep)

	if existing, found := t.tablesM[key]; found {
		existing.name, existing.mi, existing.fi, existing.inner = key, mi, fi, inner
		return existing
	}

	seq := len(t.tables) + 1
	created := &dbTable{
		id:    seq,
		index: fmt.Sprintf("T%d", seq),
		name:  key,
		names: names,
		sel:   false,
		inner: inner,
		mi:    mi,
		fi:    fi,
		jtl:   nil,
	}
	t.tablesM[key] = created
	t.tables = append(t.tables, created)
	return created
}

// add registers the table reached through the relation path names unless the
// path is already known. It returns the stored entry and reports whether a
// new entry was created; an existing entry is returned untouched.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
	key := strings.Join(names, ExprSep)

	if existing, found := t.tablesM[key]; found {
		return existing, false
	}

	seq := len(t.tables) + 1
	created := &dbTable{
		id:    seq,
		index: fmt.Sprintf("T%d", seq),
		name:  key,
		names: names,
		sel:   false,
		inner: inner,
		mi:    mi,
		fi:    fi,
		jtl:   nil,
	}
	t.tablesM[key] = created
	t.tables = append(t.tables, created)
	return created, true
}

// get looks up a table by its ExprSep-joined path name and reports whether
// it is present in the collection.
func (t *dbTables) get(name string) (*dbTable, bool) {
	tbl, found := t.tablesM[name]
	return tbl, found
}

// loopDepth appends the relation path of fi, and recursively the paths of the
// relation fields of its related model, to related.
// Recursion stops once depth drops below zero or at a many-to-many field;
// every level of recursion consumes one unit of depth.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
	if depth < 0 || fi.fieldType == RelManyToMany {
		return related
	}

	// Path of this field relative to the base model.
	path := fi.name
	if prefix != "" {
		path = prefix + ExprSep + fi.name
	}
	related = append(related, path)

	remaining := depth - 1
	for _, child := range fi.relModelInfo.fields.fieldsRel {
		related = t.loopDepth(remaining, path, child, related)
	}
	return related
}

// parseRelated registers the joins needed for related-field loading.
// rels are relation paths requested explicitly; when none are given, the
// relation fields of the base model are expanded automatically up to depth.
// It panics when a path element is not a non many-to-many relation field.
func (t *dbTables) parseRelated(rels []string, depth int) {
	explicit := len(rels)
	paths := make([]string, explicit)
	copy(paths, rels)

	// Explicit paths switch automatic expansion off: the walk then starts
	// below zero and loopDepth returns immediately for every field.
	walkDepth := depth - 1
	if explicit != 0 {
		walkDepth = -1
	}
	for _, rel := range t.mi.fields.fieldsRel {
		paths = t.loopDepth(walkDepth, "", rel, paths)
	}

	for pos, path := range paths {
		segments := strings.Split(path, ExprSep)
		names := make([]string, 0, len(segments))
		model := t.mi
		selectable := true // cleared for the rest of the path once a reverse field is seen
		inner := true      // cleared for the rest of the path once a join must be outer
		var prev *dbTable

		for _, seg := range segments {
			fi, ok := model.fields.GetByAny(seg)
			if !ok || !fi.rel || fi.fieldType == RelManyToMany {
				panic(fmt.Errorf("unknown model/table name `%s`", seg))
			}

			names = append(names, fi.name)
			model = fi.relModelInfo

			if fi.null || t.skipEnd {
				inner = false
			}

			jt := t.set(names, model, fi, inner)
			jt.jtl = prev

			if fi.reverse {
				selectable = false
			}
			if selectable {
				// Explicitly requested paths come first in paths.
				jt.sel = depth > 0 || pos < explicit
			}

			prev = jt
		}
	}
}

// getJoinSQL renders one JOIN clause per registered table, in registration
// order. Each clause ends with a single space; the result is empty when no
// table is registered.
func (t *dbTables) getJoinSQL() (join string) {
	Q := t.base.TableQuote()

	for _, jt := range t.tables {
		keyword := "LEFT OUTER JOIN "
		if jt.inner {
			keyword = "INNER JOIN "
		}
		join += keyword

		// Join against the previous table on the path, or the base table.
		parentAlias := "T0"
		if jt.jtl != nil {
			parentAlias = jt.jtl.index
		}
		alias := jt.index
		tableName := jt.mi.table

		var parentCol, ownCol string
		if jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany {
			// The joined table points back at the field's own model: match
			// that model's primary key with the first relation column of the
			// joined table that targets it.
			parentCol = jt.fi.mi.fields.pk.column
			for _, candidate := range jt.mi.fields.fieldsRel {
				if jt.fi.mi == candidate.relModelInfo {
					ownCol = candidate.column
					break
				}
			}
		} else {
			parentCol = jt.fi.column
			ownCol = jt.fi.relModelInfo.fields.pk.column

			if jt.fi.reverse {
				parentCol = jt.mi.fields.pk.column
				ownCol = jt.fi.reverseFieldInfo.column
			}
		}

		join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, tableName, Q, alias,
			alias, Q, ownCol, Q, parentAlias, Q, parentCol, Q)
	}
	return
}

// parseExprs resolves a field expression path (already split into its
// elements) starting at model mi, registering the joins the path requires
// along the way.
//
// It returns:
//   - index: alias of the table that holds the resolved column ("T0" when the
//     column lives on the base table or no join was registered);
//   - name: the last joined table's name + ExprSep + field name, or just the
//     field name when no join was registered;
//   - info: field info of the resolved column;
//   - success: false as soon as one path element cannot be resolved.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
	var (
		// last join registered on this path, nil while none was registered.
		jtl *dbTable
		// field resolved for the current path element.
		fi  *fieldInfo
		// lookahead: field resolved for the next path element, if any.
		fiN *fieldInfo
		// model the next path element is looked up in.
		mmi = mi
	)

	// index of the last path element.
	num := len(exprs) - 1
	var names []string

	// stays false for the rest of the path once a join must be an outer join.
	inner := true

loopFor:
	for i, ex := range exprs {

		var ok, okN bool

		// Elements after the first were already resolved as the previous
		// iteration's lookahead.
		if fiN != nil {
			fi = fiN
			ok = true
			fiN = nil
		}

		if i == 0 {
			fi, ok = mmi.fields.GetByAny(ex)
		}

		// no-op reference to okN.
		_ = okN

		if ok {

			isRel := fi.rel || fi.reverse

			names = append(names, fi.name)

			// Step into the model this field points at.
			switch {
			case fi.rel:
				mmi = fi.relModelInfo
				if fi.fieldType == RelManyToMany {
					mmi = fi.relThroughModelInfo
				}
			case fi.reverse:
				mmi = fi.reverseFieldInfo.mi
			}

			// Resolve the next element now, in the model just stepped into.
			if i < num {
				fiN, okN = mmi.fields.GetByAny(exprs[i+1])
			}

			// A relation needs a join unless it is the last element and its
			// owning model is a through model.
			if isRel && (fi.mi.isThrough == false || num != i) {
				if fi.null || t.skipEnd {
					inner = false
				}

				// With skipEnd the join is only registered when a following
				// element resolved; without skipEnd it is always registered.
				if t.skipEnd && okN || !t.skipEnd {
					// skipEnd and the next element is a primary key: finish
					// on the current field without registering a join.
					if t.skipEnd && okN && fiN.pk {
						goto loopEnd
					}

					jt, _ := t.add(names, mmi, fi, inner)
					jt.jtl = jtl
					jtl = jt
				}

			}

			// More elements to process: only the last one produces the result.
			if num != i {
				continue
			}

		loopEnd:

			if i == 0 || jtl == nil {
				index = "T0"
			} else {
				index = jtl.index
			}

			info = fi

			if jtl == nil {
				name = fi.name
			} else {
				name = jtl.name + ExprSep + fi.name
			}

			switch {
			case fi.rel:

			case fi.reverse:
				// A reverse one-to-one / foreign-key field resolves to the
				// primary key of the model that declares the relation.
				// NOTE(review): jtl is dereferenced without a nil check; it is
				// nil when no join was registered above (e.g. skipEnd on the
				// last element) — confirm callers cannot reach this state.
				switch fi.reverseFieldInfo.fieldType {
				case RelOneToOne, RelForeignKey:
					index = jtl.index
					info = fi.reverseFieldInfo.mi.fields.pk
					name = info.name
				}
			}

			break loopFor

		} else {
			// Unresolvable element: report failure with zeroed results.
			index = ""
			name = ""
			info = nil
			success = false
			return
		}
	}

	success = index != "" && info != nil
	return
}

// getCondSQL renders cond as SQL together with its bound parameters.
// Nested conditions are rendered recursively with sub set to true and wrapped
// in parentheses; only the outermost call (sub == false) prefixes a non-empty
// result with "WHERE ". A nil or empty condition yields an empty string and
// nil params. It panics when a field expression cannot be resolved.
func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
	if cond == nil || cond.IsEmpty() {
		return
	}

	Q := t.base.TableQuote()
	mi := t.mi

	for pos, part := range cond.params {
		// Connective between this part and the previous one.
		switch {
		case pos == 0:
			// the first part has no connective.
		case part.isOr:
			where += "OR "
		default:
			where += "AND "
		}
		if part.isNot {
			where += "NOT "
		}

		if part.isCond {
			nested, nestedArgs := t.getCondSQL(part.cond, true, tz)
			if nested != "" {
				nested = "( " + nested + ") "
			}
			where += nested
			params = append(params, nestedArgs...)
			continue
		}

		// A trailing element that names a known operator is split off the
		// field path.
		exprs := part.exprs
		last := len(exprs) - 1
		operator := ""
		if operators[exprs[last]] {
			operator = exprs[last]
			exprs = exprs[:last]
		}

		index, _, fi, found := t.parseExprs(mi, exprs)
		if !found {
			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(part.exprs, ExprSep)))
		}

		if operator == "" {
			operator = "exact"
		}

		operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, part.args, tz)

		leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
		t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)

		where += fmt.Sprintf("%s %s ", leftCol, operSQL)
		params = append(params, args...)
	}

	if !sub && where != "" {
		where = "WHERE " + where
	}
	return
}

// getGroupSQL renders the GROUP BY clause for the given field expressions,
// followed by a single space. It returns an empty string when groups is empty
// and panics when an expression cannot be resolved.
func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
	if len(groups) == 0 {
		return
	}

	Q := t.base.TableQuote()

	columns := make([]string, 0, len(groups))
	for _, group := range groups {
		exprs := strings.Split(group, ExprSep)

		index, _, fi, found := t.parseExprs(t.mi, exprs)
		if !found {
			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
		}

		columns = append(columns, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
	}

	return "GROUP BY " + strings.Join(columns, ", ") + " "
}

// getOrderSQL renders the ORDER BY clause for the given order expressions,
// followed by a single space. An expression prefixed with "-" sorts
// descending, any other expression ascending.
// It returns an empty string when orders is empty and panics when an
// expression cannot be resolved to a field.
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
	if len(orders) == 0 {
		return
	}

	Q := t.base.TableQuote()

	orderSqls := make([]string, 0, len(orders))
	for _, order := range orders {
		asc := "ASC"
		// HasPrefix instead of order[0]: indexing an empty entry would abort
		// with an index-out-of-range runtime panic before the entry ever
		// reaches the field lookup below.
		if strings.HasPrefix(order, "-") {
			asc = "DESC"
			order = order[1:]
		}
		exprs := strings.Split(order, ExprSep)

		index, _, fi, suc := t.parseExprs(t.mi, exprs)
		if !suc {
			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
		}

		orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
	}

	orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
	return
}

// getLimitSQL renders the LIMIT/OFFSET clause.
// A limit of 0 is replaced by DefaultRowsLimit; a negative limit means "no
// limit". With no limit and a positive offset, the backend's MaxLimit is used
// as the LIMIT value when it is non-zero, otherwise only OFFSET is emitted.
// With no limit and no positive offset the result is empty.
// mi is not used.
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
	if limit == 0 {
		limit = int64(DefaultRowsLimit)
	}

	switch {
	case limit < 0 && offset <= 0:
		return ""
	case limit < 0:
		if maxLimit := t.base.MaxLimit(); maxLimit != 0 {
			return fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
		}
		return fmt.Sprintf("OFFSET %d", offset)
	case offset <= 0:
		return fmt.Sprintf("LIMIT %d", limit)
	default:
		return fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
	}
}

// newDbTables creates an empty tables collection for base model mi that uses
// base as its backend.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
	return &dbTables{
		tablesM: make(map[string]*dbTable),
		mi:      mi,
		base:    base,
	}
}