// File: util.go

  1. package validator
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strconv"
  6. "strings"
  7. )
  // Tokens used when comparing against empty strings and when walking a
  // field namespace such as "Inner.Items[0].Name".
  const (
  	dash               = "-"
  	blank              = ""
  	namespaceSeparator = "."
  	leftBracket        = "["
  	rightBracket       = "]"
  )

  // Characters that may not appear in a tag or alias name, plus the message
  // formats reporting a rejected name. Each format takes the offending name
  // as its single %s argument.
  const (
  	restrictedTagChars = ".[],|=+()`~!@#$%^&*\\\"/?<>{}"
  	restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
  	restrictedTagErr   = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
  )
  18. var (
  19. restrictedTags = map[string]*struct{}{
  20. diveTag: emptyStructPtr,
  21. existsTag: emptyStructPtr,
  22. structOnlyTag: emptyStructPtr,
  23. omitempty: emptyStructPtr,
  24. skipValidationTag: emptyStructPtr,
  25. utf8HexComma: emptyStructPtr,
  26. utf8Pipe: emptyStructPtr,
  27. noStructLevelTag: emptyStructPtr,
  28. }
  29. )
  30. // ExtractType gets the actual underlying type of field value.
  31. // It will dive into pointers, customTypes and return you the
  32. // underlying value and it's kind.
  33. // it is exposed for use within you Custom Functions
  34. func (v *Validate) ExtractType(current reflect.Value) (reflect.Value, reflect.Kind) {
  35. switch current.Kind() {
  36. case reflect.Ptr:
  37. if current.IsNil() {
  38. return current, reflect.Ptr
  39. }
  40. return v.ExtractType(current.Elem())
  41. case reflect.Interface:
  42. if current.IsNil() {
  43. return current, reflect.Interface
  44. }
  45. return v.ExtractType(current.Elem())
  46. case reflect.Invalid:
  47. return current, reflect.Invalid
  48. default:
  49. if v.hasCustomFuncs {
  50. // fmt.Println("Type", current.Type())
  51. if fn, ok := v.customTypeFuncs[current.Type()]; ok {
  52. // fmt.Println("OK")
  53. return v.ExtractType(reflect.ValueOf(fn(current)))
  54. }
  55. // fmt.Println("NOT OK")
  56. }
  57. return current, current.Kind()
  58. }
  59. }
  60. // GetStructFieldOK traverses a struct to retrieve a specific field denoted by the provided namespace and
  61. // returns the field, field kind and whether is was successful in retrieving the field at all.
  62. // NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field
  63. // could not be retrived because it didnt exist.
  64. func (v *Validate) GetStructFieldOK(current reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) {
  65. current, kind := v.ExtractType(current)
  66. if kind == reflect.Invalid {
  67. return current, kind, false
  68. }
  69. if namespace == blank {
  70. return current, kind, true
  71. }
  72. switch kind {
  73. case reflect.Ptr, reflect.Interface:
  74. return current, kind, false
  75. case reflect.Struct:
  76. typ := current.Type()
  77. fld := namespace
  78. ns := namespace
  79. if typ != timeType && typ != timePtrType {
  80. idx := strings.Index(namespace, namespaceSeparator)
  81. if idx != -1 {
  82. fld = namespace[:idx]
  83. ns = namespace[idx+1:]
  84. } else {
  85. ns = blank
  86. idx = len(namespace)
  87. }
  88. bracketIdx := strings.Index(fld, leftBracket)
  89. if bracketIdx != -1 {
  90. fld = fld[:bracketIdx]
  91. ns = namespace[bracketIdx:]
  92. }
  93. current = current.FieldByName(fld)
  94. return v.GetStructFieldOK(current, ns)
  95. }
  96. case reflect.Array, reflect.Slice:
  97. idx := strings.Index(namespace, leftBracket)
  98. idx2 := strings.Index(namespace, rightBracket)
  99. arrIdx, _ := strconv.Atoi(namespace[idx+1 : idx2])
  100. if arrIdx >= current.Len() {
  101. return current, kind, false
  102. }
  103. startIdx := idx2 + 1
  104. if startIdx < len(namespace) {
  105. if namespace[startIdx:startIdx+1] == namespaceSeparator {
  106. startIdx++
  107. }
  108. }
  109. return v.GetStructFieldOK(current.Index(arrIdx), namespace[startIdx:])
  110. case reflect.Map:
  111. idx := strings.Index(namespace, leftBracket) + 1
  112. idx2 := strings.Index(namespace, rightBracket)
  113. endIdx := idx2
  114. if endIdx+1 < len(namespace) {
  115. if namespace[endIdx+1:endIdx+2] == namespaceSeparator {
  116. endIdx++
  117. }
  118. }
  119. key := namespace[idx:idx2]
  120. switch current.Type().Key().Kind() {
  121. case reflect.Int:
  122. i, _ := strconv.Atoi(key)
  123. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
  124. case reflect.Int8:
  125. i, _ := strconv.ParseInt(key, 10, 8)
  126. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int8(i))), namespace[endIdx+1:])
  127. case reflect.Int16:
  128. i, _ := strconv.ParseInt(key, 10, 16)
  129. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int16(i))), namespace[endIdx+1:])
  130. case reflect.Int32:
  131. i, _ := strconv.ParseInt(key, 10, 32)
  132. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int32(i))), namespace[endIdx+1:])
  133. case reflect.Int64:
  134. i, _ := strconv.ParseInt(key, 10, 64)
  135. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
  136. case reflect.Uint:
  137. i, _ := strconv.ParseUint(key, 10, 0)
  138. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint(i))), namespace[endIdx+1:])
  139. case reflect.Uint8:
  140. i, _ := strconv.ParseUint(key, 10, 8)
  141. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint8(i))), namespace[endIdx+1:])
  142. case reflect.Uint16:
  143. i, _ := strconv.ParseUint(key, 10, 16)
  144. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint16(i))), namespace[endIdx+1:])
  145. case reflect.Uint32:
  146. i, _ := strconv.ParseUint(key, 10, 32)
  147. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint32(i))), namespace[endIdx+1:])
  148. case reflect.Uint64:
  149. i, _ := strconv.ParseUint(key, 10, 64)
  150. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
  151. case reflect.Float32:
  152. f, _ := strconv.ParseFloat(key, 32)
  153. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(float32(f))), namespace[endIdx+1:])
  154. case reflect.Float64:
  155. f, _ := strconv.ParseFloat(key, 64)
  156. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(f)), namespace[endIdx+1:])
  157. case reflect.Bool:
  158. b, _ := strconv.ParseBool(key)
  159. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(b)), namespace[endIdx+1:])
  160. // reflect.Type = string
  161. default:
  162. return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(key)), namespace[endIdx+1:])
  163. }
  164. }
  165. // if got here there was more namespace, cannot go any deeper
  166. panic("Invalid field namespace")
  167. }
  168. // asInt retuns the parameter as a int64
  169. // or panics if it can't convert
  170. func asInt(param string) int64 {
  171. i, err := strconv.ParseInt(param, 0, 64)
  172. panicIf(err)
  173. return i
  174. }
  175. // asUint returns the parameter as a uint64
  176. // or panics if it can't convert
  177. func asUint(param string) uint64 {
  178. i, err := strconv.ParseUint(param, 0, 64)
  179. panicIf(err)
  180. return i
  181. }
  182. // asFloat returns the parameter as a float64
  183. // or panics if it can't convert
  184. func asFloat(param string) float64 {
  185. i, err := strconv.ParseFloat(param, 64)
  186. panicIf(err)
  187. return i
  188. }
  // panicIf panics with err's message (a string, not the error value itself)
  // when err is non-nil, and does nothing otherwise.
  func panicIf(err error) {
  	if err == nil {
  		return
  	}

  	panic(err.Error())
  }
  194. func (v *Validate) parseStruct(current reflect.Value, sName string) *cachedStruct {
  195. typ := current.Type()
  196. s := &cachedStruct{Name: sName, fields: map[int]cachedField{}}
  197. numFields := current.NumField()
  198. var fld reflect.StructField
  199. var tag string
  200. var customName string
  201. for i := 0; i < numFields; i++ {
  202. fld = typ.Field(i)
  203. if fld.PkgPath != blank {
  204. continue
  205. }
  206. tag = fld.Tag.Get(v.tagName)
  207. if tag == skipValidationTag {
  208. continue
  209. }
  210. customName = fld.Name
  211. if v.fieldNameTag != blank {
  212. name := strings.SplitN(fld.Tag.Get(v.fieldNameTag), ",", 2)[0]
  213. // dash check is for json "-" (aka skipValidationTag) means don't output in json
  214. if name != "" && name != skipValidationTag {
  215. customName = name
  216. }
  217. }
  218. cTag, ok := v.tagCache.Get(tag)
  219. if !ok {
  220. cTag = v.parseTags(tag, fld.Name)
  221. }
  222. s.fields[i] = cachedField{Idx: i, Name: fld.Name, AltName: customName, CachedTag: cTag}
  223. }
  224. v.structCache.Set(typ, s)
  225. return s
  226. }
  227. func (v *Validate) parseTags(tag, fieldName string) *cachedTag {
  228. cTag := &cachedTag{tag: tag}
  229. v.parseTagsRecursive(cTag, tag, fieldName, blank, false)
  230. v.tagCache.Set(tag, cTag)
  231. return cTag
  232. }
  233. func (v *Validate) parseTagsRecursive(cTag *cachedTag, tag, fieldName, alias string, isAlias bool) bool {
  234. if tag == blank {
  235. return true
  236. }
  237. for _, t := range strings.Split(tag, tagSeparator) {
  238. if v.hasAliasValidators {
  239. // check map for alias and process new tags, otherwise process as usual
  240. if tagsVal, ok := v.aliasValidators[t]; ok {
  241. leave := v.parseTagsRecursive(cTag, tagsVal, fieldName, t, true)
  242. if leave {
  243. return leave
  244. }
  245. continue
  246. }
  247. }
  248. switch t {
  249. case diveTag:
  250. cTag.diveTag = tag
  251. tVals := &tagVals{tagVals: [][]string{{t}}}
  252. cTag.tags = append(cTag.tags, tVals)
  253. return true
  254. case omitempty:
  255. cTag.isOmitEmpty = true
  256. case structOnlyTag:
  257. cTag.isStructOnly = true
  258. case noStructLevelTag:
  259. cTag.isNoStructLevel = true
  260. }
  261. // if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
  262. orVals := strings.Split(t, orSeparator)
  263. tagVal := &tagVals{isAlias: isAlias, isOrVal: len(orVals) > 1, tagVals: make([][]string, len(orVals))}
  264. cTag.tags = append(cTag.tags, tagVal)
  265. var key string
  266. var param string
  267. for i, val := range orVals {
  268. vals := strings.SplitN(val, tagKeySeparator, 2)
  269. key = vals[0]
  270. tagVal.tag = key
  271. if isAlias {
  272. tagVal.tag = alias
  273. }
  274. if key == blank {
  275. panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
  276. }
  277. if len(vals) > 1 {
  278. param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
  279. }
  280. tagVal.tagVals[i] = []string{key, param}
  281. }
  282. }
  283. return false
  284. }