Skip to content
This repository has been archived by the owner on Feb 17, 2024. It is now read-only.

Commit

Permalink
timecraft: simplify enum options
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed May 29, 2023
1 parent ebb757f commit 170070d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 41 deletions.
72 changes: 38 additions & 34 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"log"
_ "net/http/pprof"
"os"
"time"
"strings"

"github.com/stealthrocket/timecraft/internal/object"
"github.com/stealthrocket/timecraft/internal/print/human"
Expand Down Expand Up @@ -103,20 +103,44 @@ func Root(ctx context.Context, args ...string) int {
}
}

func setEnum[T ~string](enum *T, typ string, value string, options ...string) error {
for _, option := range options {
if option == value {
*enum = T(option)
return nil
}
}
return fmt.Errorf("unsupported %s: %q (not one of %s)", typ, value, strings.Join(options, ", "))
}

type compression string

func (c compression) String() string {
return string(c)
}

func (c *compression) Set(value string) error {
return setEnum(c, "compression type", value, "snappy", "zstd", "none")
}

type sockets string

func (s sockets) String() string {
return string(s)
}

func (s *sockets) Set(value string) error {
return setEnum(s, "sockets extension", value, "none", "auto", "path_open", "wasmedgev1", "wasmedgev2")
}

type outputFormat string

func (o outputFormat) String() string {
return string(o)
}

func (o *outputFormat) Set(value string) error {
switch value {
case "text", "json", "yaml":
*o = outputFormat(value)
return nil
default:
return fmt.Errorf("unsupported output format: %q", value)
}
return setEnum(o, "output format", value, "text", "json", "yaml")
}

type stringList []string
Expand Down Expand Up @@ -171,36 +195,16 @@ func parseFlags(f *flag.FlagSet, args []string) {
}
}

func customVar(f *flag.FlagSet, dst flag.Value, name string, alias ...string) {
f.Var(dst, name, "")
func boolVar(f *flag.FlagSet, dst *bool, name string, alias ...string) {
f.BoolVar(dst, name, *dst, "")
for _, name := range alias {
f.Var(dst, name, "")
f.BoolVar(dst, name, *dst, "")
}
}

func durationVar(f *flag.FlagSet, dst *time.Duration, name string, alias ...string) {
setFlagVar(f.DurationVar, dst, name, alias)
}

func stringVar(f *flag.FlagSet, dst *string, name string, alias ...string) {
setFlagVar(f.StringVar, dst, name, alias)
}

func boolVar(f *flag.FlagSet, dst *bool, name string, alias ...string) {
setFlagVar(f.BoolVar, dst, name, alias)
}

func intVar(f *flag.FlagSet, dst *int, name string, alias ...string) {
setFlagVar(f.IntVar, dst, name, alias)
}

func float64Var(f *flag.FlagSet, dst *float64, name string, alias ...string) {
setFlagVar(f.Float64Var, dst, name, alias)
}

func setFlagVar[T any](set func(*T, string, T, string), dst *T, name string, alias []string) {
set(dst, name, *dst, "")
func customVar(f *flag.FlagSet, dst flag.Value, name string, alias ...string) {
f.Var(dst, name, "")
for _, name := range alias {
set(dst, name, *dst, "")
f.Var(dst, name, "")
}
}
13 changes: 6 additions & 7 deletions internal/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -45,8 +44,8 @@ func run(ctx context.Context, args []string) error {
listens stringList
dials stringList
batchSize = human.Count(4096)
compression = "zstd"
sockets = "auto"
compression = compression("zstd")
sockets = sockets("auto")
registryPath = human.Path("~/.timecraft")
record = false
trace = false
Expand All @@ -56,12 +55,12 @@ func run(ctx context.Context, args []string) error {
customVar(flagSet, &envs, "e", "env")
customVar(flagSet, &listens, "L", "listen")
customVar(flagSet, &dials, "D", "dial")
stringVar(flagSet, &sockets, "S", "sockets")
customVar(flagSet, &sockets, "S", "sockets")
customVar(flagSet, &registryPath, "r", "registry")
boolVar(flagSet, &trace, "T", "trace")
boolVar(flagSet, &record, "R", "record")
customVar(flagSet, &batchSize, "record-batch-size")
stringVar(flagSet, &compression, "record-compression")
customVar(flagSet, &compression, "record-compression")
parseFlags(flagSet, args)

envs = append(os.Environ(), envs...)
Expand Down Expand Up @@ -106,12 +105,12 @@ func run(ctx context.Context, args []string) error {
WithListens(listens...).
WithDials(dials...).
WithStdio(stdin, stdout, stderr).
WithSocketsExtension(sockets, wasmModule).
WithSocketsExtension(string(sockets), wasmModule).
WithTracer(trace, os.Stderr)

if record {
var c timemachine.Compression
switch strings.ToLower(compression) {
switch compression {
case "snappy":
c = timemachine.Snappy
case "zstd":
Expand Down

0 comments on commit 170070d

Please sign in to comment.