bash_completions.go 17 KB


  1. package cobra
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "os"
  7. "sort"
  8. "strings"
  9. "github.com/spf13/pflag"
  10. )
  // Annotation keys understood by the bash completion generator. They are attached
// to individual flags with FlagSet.SetAnnotation (see the Mark* helpers in this file).
const (
	// BashCompFilenameExt marks a flag whose value is a file name. The annotation
	// values, when present, limit completion to those extensions; with no values
	// plain file-name completion is used.
	BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extensions"

	// BashCompCustom carries the bash command(s) run to complete the flag's value.
	BashCompCustom = "cobra_annotation_bash_completion_custom"

	// BashCompOneRequiredFlag marks a flag that is listed in the generated
	// script's must_have_one_flag array.
	BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"

	// BashCompSubdirsInDir restricts completion to subdirectories of the single
	// directory given as the annotation value; otherwise any directory is offered.
	BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
)
  // writePreamble writes the header line and the shared bash helper functions
// (__<name>_debug, __<name>_handle_reply, __<name>_handle_word, ...) for the
// completion script of the command called name.
//
// The helpers are named with name exactly as given. The root command's own
// function, however, is emitted by gen with every ":" replaced by "__", so the
// root lookup in __<name>_handle_command is built from that sanitized form.
//
// The template is rendered by fmt.Sprintf: a bash "%" must be written as "%%",
// and bash's "%%" (remove longest matching suffix) therefore as "%%%%".
func writePreamble(buf *bytes.Buffer, name string) {
	rootFunc := strings.Replace(name, ":", "__", -1)
	buf.WriteString(fmt.Sprintf("# bash completion for %-36s -*- shell-script -*-\n", name))
	buf.WriteString(fmt.Sprintf(`
__%[1]s_debug()
{
if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
echo "$*" >> "${BASH_COMP_DEBUG_FILE}"
fi
}
# Homebrew on Macs have version 1.3 of bash-completion which doesn't include
# _init_completion. This is a very minimal version of that function.
__%[1]s_init_completion()
{
COMPREPLY=()
_get_comp_words_by_ref "$@" cur prev words cword
}
__%[1]s_index_of_word()
{
local w word=$1
shift
index=0
for w in "$@"; do
[[ $w = "$word" ]] && return
index=$((index+1))
done
index=-1
}
__%[1]s_contains_word()
{
local w word=$1; shift
for w in "$@"; do
[[ $w = "$word" ]] && return
done
return 1
}
__%[1]s_handle_reply()
{
__%[1]s_debug "${FUNCNAME[0]}"
case $cur in
-*)
if [[ $(type -t compopt) = "builtin" ]]; then
compopt -o nospace
fi
local allflags
if [ ${#must_have_one_flag[@]} -ne 0 ]; then
allflags=("${must_have_one_flag[@]}")
else
allflags=("${flags[*]} ${two_word_flags[*]}")
fi
COMPREPLY=( $(compgen -W "${allflags[*]}" -- "$cur") )
if [[ $(type -t compopt) = "builtin" ]]; then
[[ "${COMPREPLY[0]}" == *= ]] || compopt +o nospace
fi
# complete after --flag=abc
if [[ $cur == *=* ]]; then
if [[ $(type -t compopt) = "builtin" ]]; then
compopt +o nospace
fi
local index flag
flag="${cur%%%%=*}"
__%[1]s_index_of_word "${flag}" "${flags_with_completion[@]}"
COMPREPLY=()
if [[ ${index} -ge 0 ]]; then
PREFIX=""
cur="${cur#*=}"
${flags_completion[${index}]}
if [ -n "${ZSH_VERSION}" ]; then
# zsh completion needs --flag= prefix
eval "COMPREPLY=( \"\${COMPREPLY[@]/#/${flag}=}\" )"
fi
fi
fi
return 0;
;;
esac
# check if we are handling a flag with special work handling
local index
__%[1]s_index_of_word "${prev}" "${flags_with_completion[@]}"
if [[ ${index} -ge 0 ]]; then
${flags_completion[${index}]}
return
fi
# we are parsing a flag and don't have a special handler, no completion
if [[ ${cur} != "${words[cword]}" ]]; then
return
fi
local completions
completions=("${commands[@]}")
if [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
completions=("${must_have_one_noun[@]}")
fi
if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
completions+=("${must_have_one_flag[@]}")
fi
COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )
if [[ ${#COMPREPLY[@]} -eq 0 && ${#noun_aliases[@]} -gt 0 && ${#must_have_one_noun[@]} -ne 0 ]]; then
COMPREPLY=( $(compgen -W "${noun_aliases[*]}" -- "$cur") )
fi
if [[ ${#COMPREPLY[@]} -eq 0 ]]; then
declare -F __custom_func >/dev/null && __custom_func
fi
# available in bash-completion >= 2, not always present on macOS
if declare -F __ltrim_colon_completions >/dev/null; then
__ltrim_colon_completions "$cur"
fi
# If there is only 1 completion and it is a flag with an = it will be completed
# but we don't want a space after the =
if [[ "${#COMPREPLY[@]}" -eq "1" ]] && [[ $(type -t compopt) = "builtin" ]] && [[ "${COMPREPLY[0]}" == --*= ]]; then
compopt -o nospace
fi
}
# The arguments should be in the form "ext1|ext2|extn"
__%[1]s_handle_filename_extension_flag()
{
local ext="$1"
_filedir "@(${ext})"
}
__%[1]s_handle_subdirs_in_dir_flag()
{
local dir="$1"
pushd "${dir}" >/dev/null 2>&1 || return
_filedir -d
popd >/dev/null 2>&1
}
__%[1]s_handle_flag()
{
__%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"
# if a command required a flag, and we found it, unset must_have_one_flag()
local flagname=${words[c]}
local flagvalue
# if the word contained an =
if [[ ${words[c]} == *"="* ]]; then
flagvalue=${flagname#*=} # take in as flagvalue after the =
flagname=${flagname%%%%=*} # strip everything after the =
flagname="${flagname}=" # but put the = back
fi
__%[1]s_debug "${FUNCNAME[0]}: looking for ${flagname}"
if __%[1]s_contains_word "${flagname}" "${must_have_one_flag[@]}"; then
must_have_one_flag=()
fi
# if you set a flag which only applies to this command, don't show subcommands
if __%[1]s_contains_word "${flagname}" "${local_nonpersistent_flags[@]}"; then
commands=()
fi
# keep flag value with flagname as flaghash
# flaghash variable is an associative array which is only supported in bash > 3.
if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
if [ -n "${flagvalue}" ] ; then
flaghash[${flagname}]=${flagvalue}
elif [ -n "${words[ $((c+1)) ]}" ] ; then
flaghash[${flagname}]=${words[ $((c+1)) ]}
else
flaghash[${flagname}]="true" # pad "true" for bool flag
fi
fi
# skip the argument to a two word flag
if __%[1]s_contains_word "${words[c]}" "${two_word_flags[@]}"; then
c=$((c+1))
# if we are looking for a flags value, don't show commands
if [[ $c -eq $cword ]]; then
commands=()
fi
fi
c=$((c+1))
}
__%[1]s_handle_noun()
{
__%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"
if __%[1]s_contains_word "${words[c]}" "${must_have_one_noun[@]}"; then
must_have_one_noun=()
elif __%[1]s_contains_word "${words[c]}" "${noun_aliases[@]}"; then
must_have_one_noun=()
fi
nouns+=("${words[c]}")
c=$((c+1))
}
__%[1]s_handle_command()
{
__%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"
local next_command
if [[ -n ${last_command} ]]; then
next_command="_${last_command}_${words[c]//:/__}"
else
if [[ $c -eq 0 ]]; then
next_command="_%[2]s_root_command"
else
next_command="_${words[c]//:/__}"
fi
fi
c=$((c+1))
__%[1]s_debug "${FUNCNAME[0]}: looking for ${next_command}"
declare -F "$next_command" >/dev/null && $next_command
}
__%[1]s_handle_word()
{
if [[ $c -ge $cword ]]; then
__%[1]s_handle_reply
return
fi
__%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"
if [[ "${words[c]}" == -* ]]; then
__%[1]s_handle_flag
elif __%[1]s_contains_word "${words[c]}" "${commands[@]}"; then
__%[1]s_handle_command
elif [[ $c -eq 0 ]]; then
__%[1]s_handle_command
else
__%[1]s_handle_noun
fi
__%[1]s_handle_word
}
`, name, rootFunc))
}
  // writePostscript writes the __start_<name> entry point for the command called
// name and registers it with bash's "complete" builtin.
//
// Only the entry point's own function name is sanitized (":" replaced by "__").
// The helpers it calls are the ones writePreamble defines, which use name
// unchanged. "complete" must also register the command name exactly as typed;
// registering the sanitized form would attach completion to a command that
// does not exist.
//
// noun_aliases is declared local next to the other per-command arrays. The
// generated command functions assign it, and without the declaration it would
// leak into the interactive shell as a global.
func writePostscript(buf *bytes.Buffer, name string) {
	startFunc := "__start_" + strings.Replace(name, ":", "__", -1)
	buf.WriteString(startFunc + "()\n")
	buf.WriteString(fmt.Sprintf(`{
local cur prev words cword
declare -A flaghash 2>/dev/null || :
if declare -F _init_completion >/dev/null 2>&1; then
_init_completion -s || return
else
__%[1]s_init_completion -n "=" || return
fi
local c=0
local flags=()
local two_word_flags=()
local local_nonpersistent_flags=()
local flags_with_completion=()
local flags_completion=()
local commands=("%[1]s")
local must_have_one_flag=()
local must_have_one_noun=()
local noun_aliases=()
local last_command
local nouns=()
__%[1]s_handle_word
}
`, name))
	buf.WriteString(fmt.Sprintf(`if [[ $(type -t compopt) = "builtin" ]]; then
complete -o default -F %[1]s %[2]s
else
complete -o default -o nospace -F %[1]s %[2]s
fi
`, startFunc, name))
	buf.WriteString("# ex: ts=4 sw=4 et filetype=sh\n")
}
  262. func writeCommands(buf *bytes.Buffer, cmd *Command) {
  263. buf.WriteString(" commands=()\n")
  264. for _, c := range cmd.Commands() {
  265. if !c.IsAvailableCommand() || c == cmd.helpCommand {
  266. continue
  267. }
  268. buf.WriteString(fmt.Sprintf(" commands+=(%q)\n", c.Name()))
  269. }
  270. buf.WriteString("\n")
  271. }
  272. func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]string, cmd *Command) {
  273. for key, value := range annotations {
  274. switch key {
  275. case BashCompFilenameExt:
  276. buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name))
  277. var ext string
  278. if len(value) > 0 {
  279. ext = fmt.Sprintf("__%s_handle_filename_extension_flag ", cmd.Root().Name()) + strings.Join(value, "|")
  280. } else {
  281. ext = "_filedir"
  282. }
  283. buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", ext))
  284. case BashCompCustom:
  285. buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name))
  286. if len(value) > 0 {
  287. handlers := strings.Join(value, "; ")
  288. buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", handlers))
  289. } else {
  290. buf.WriteString(" flags_completion+=(:)\n")
  291. }
  292. case BashCompSubdirsInDir:
  293. buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name))
  294. var ext string
  295. if len(value) == 1 {
  296. ext = fmt.Sprintf("__%s_handle_subdirs_in_dir_flag ", cmd.Root().Name()) + value[0]
  297. } else {
  298. ext = "_filedir -d"
  299. }
  300. buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", ext))
  301. }
  302. }
  303. }
  304. func writeShortFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) {
  305. name := flag.Shorthand
  306. format := " "
  307. if len(flag.NoOptDefVal) == 0 {
  308. format += "two_word_"
  309. }
  310. format += "flags+=(\"-%s\")\n"
  311. buf.WriteString(fmt.Sprintf(format, name))
  312. writeFlagHandler(buf, "-"+name, flag.Annotations, cmd)
  313. }
  314. func writeFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) {
  315. name := flag.Name
  316. format := " flags+=(\"--%s"
  317. if len(flag.NoOptDefVal) == 0 {
  318. format += "="
  319. }
  320. format += "\")\n"
  321. buf.WriteString(fmt.Sprintf(format, name))
  322. writeFlagHandler(buf, "--"+name, flag.Annotations, cmd)
  323. }
  324. func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) {
  325. name := flag.Name
  326. format := " local_nonpersistent_flags+=(\"--%s"
  327. if len(flag.NoOptDefVal) == 0 {
  328. format += "="
  329. }
  330. format += "\")\n"
  331. buf.WriteString(fmt.Sprintf(format, name))
  332. }
  333. func writeFlags(buf *bytes.Buffer, cmd *Command) {
  334. buf.WriteString(` flags=()
  335. two_word_flags=()
  336. local_nonpersistent_flags=()
  337. flags_with_completion=()
  338. flags_completion=()
  339. `)
  340. localNonPersistentFlags := cmd.LocalNonPersistentFlags()
  341. cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
  342. if nonCompletableFlag(flag) {
  343. return
  344. }
  345. writeFlag(buf, flag, cmd)
  346. if len(flag.Shorthand) > 0 {
  347. writeShortFlag(buf, flag, cmd)
  348. }
  349. if localNonPersistentFlags.Lookup(flag.Name) != nil {
  350. writeLocalNonPersistentFlag(buf, flag)
  351. }
  352. })
  353. cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
  354. if nonCompletableFlag(flag) {
  355. return
  356. }
  357. writeFlag(buf, flag, cmd)
  358. if len(flag.Shorthand) > 0 {
  359. writeShortFlag(buf, flag, cmd)
  360. }
  361. })
  362. buf.WriteString("\n")
  363. }
  364. func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) {
  365. buf.WriteString(" must_have_one_flag=()\n")
  366. flags := cmd.NonInheritedFlags()
  367. flags.VisitAll(func(flag *pflag.Flag) {
  368. if nonCompletableFlag(flag) {
  369. return
  370. }
  371. for key := range flag.Annotations {
  372. switch key {
  373. case BashCompOneRequiredFlag:
  374. format := " must_have_one_flag+=(\"--%s"
  375. if flag.Value.Type() != "bool" {
  376. format += "="
  377. }
  378. format += "\")\n"
  379. buf.WriteString(fmt.Sprintf(format, flag.Name))
  380. if len(flag.Shorthand) > 0 {
  381. buf.WriteString(fmt.Sprintf(" must_have_one_flag+=(\"-%s\")\n", flag.Shorthand))
  382. }
  383. }
  384. }
  385. })
  386. }
  387. func writeRequiredNouns(buf *bytes.Buffer, cmd *Command) {
  388. buf.WriteString(" must_have_one_noun=()\n")
  389. sort.Sort(sort.StringSlice(cmd.ValidArgs))
  390. for _, value := range cmd.ValidArgs {
  391. buf.WriteString(fmt.Sprintf(" must_have_one_noun+=(%q)\n", value))
  392. }
  393. }
  394. func writeArgAliases(buf *bytes.Buffer, cmd *Command) {
  395. buf.WriteString(" noun_aliases=()\n")
  396. sort.Sort(sort.StringSlice(cmd.ArgAliases))
  397. for _, value := range cmd.ArgAliases {
  398. buf.WriteString(fmt.Sprintf(" noun_aliases+=(%q)\n", value))
  399. }
  400. }
  401. func gen(buf *bytes.Buffer, cmd *Command) {
  402. for _, c := range cmd.Commands() {
  403. if !c.IsAvailableCommand() || c == cmd.helpCommand {
  404. continue
  405. }
  406. gen(buf, c)
  407. }
  408. commandName := cmd.CommandPath()
  409. commandName = strings.Replace(commandName, " ", "_", -1)
  410. commandName = strings.Replace(commandName, ":", "__", -1)
  411. if cmd.Root() == cmd {
  412. buf.WriteString(fmt.Sprintf("_%s_root_command()\n{\n", commandName))
  413. } else {
  414. buf.WriteString(fmt.Sprintf("_%s()\n{\n", commandName))
  415. }
  416. buf.WriteString(fmt.Sprintf(" last_command=%q\n", commandName))
  417. writeCommands(buf, cmd)
  418. writeFlags(buf, cmd)
  419. writeRequiredFlag(buf, cmd)
  420. writeRequiredNouns(buf, cmd)
  421. writeArgAliases(buf, cmd)
  422. buf.WriteString("}\n\n")
  423. }
  424. // GenBashCompletion generates bash completion file and writes to the passed writer.
  425. func (c *Command) GenBashCompletion(w io.Writer) error {
  426. buf := new(bytes.Buffer)
  427. writePreamble(buf, c.Name())
  428. if len(c.BashCompletionFunction) > 0 {
  429. buf.WriteString(c.BashCompletionFunction + "\n")
  430. }
  431. gen(buf, c)
  432. writePostscript(buf, c.Name())
  433. _, err := buf.WriteTo(w)
  434. return err
  435. }
  436. func nonCompletableFlag(flag *pflag.Flag) bool {
  437. return flag.Hidden || len(flag.Deprecated) > 0
  438. }
  439. // GenBashCompletionFile generates bash completion file.
  440. func (c *Command) GenBashCompletionFile(filename string) error {
  441. outFile, err := os.Create(filename)
  442. if err != nil {
  443. return err
  444. }
  445. defer outFile.Close()
  446. return c.GenBashCompletion(outFile)
  447. }
  448. // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
  449. // and causes your command to report an error if invoked without the flag.
  450. func (c *Command) MarkFlagRequired(name string) error {
  451. return MarkFlagRequired(c.Flags(), name)
  452. }
  453. // MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag if it exists,
  454. // and causes your command to report an error if invoked without the flag.
  455. func (c *Command) MarkPersistentFlagRequired(name string) error {
  456. return MarkFlagRequired(c.PersistentFlags(), name)
  457. }
  458. // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
  459. // and causes your command to report an error if invoked without the flag.
  460. func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
  461. return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
  462. }
  463. // MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
  464. // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
  465. func (c *Command) MarkFlagFilename(name string, extensions ...string) error {
  466. return MarkFlagFilename(c.Flags(), name, extensions...)
  467. }
  468. // MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists.
  469. // Generated bash autocompletion will call the bash function f for the flag.
  470. func (c *Command) MarkFlagCustom(name string, f string) error {
  471. return MarkFlagCustom(c.Flags(), name, f)
  472. }
  473. // MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists.
  474. // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
  475. func (c *Command) MarkPersistentFlagFilename(name string, extensions ...string) error {
  476. return MarkFlagFilename(c.PersistentFlags(), name, extensions...)
  477. }
  478. // MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists.
  479. // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
  480. func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error {
  481. return flags.SetAnnotation(name, BashCompFilenameExt, extensions)
  482. }
  483. // MarkFlagCustom adds the BashCompCustom annotation to the named flag in the flag set, if it exists.
  484. // Generated bash autocompletion will call the bash function f for the flag.
  485. func MarkFlagCustom(flags *pflag.FlagSet, name string, f string) error {
  486. return flags.SetAnnotation(name, BashCompCustom, []string{f})
  487. }