diff --git a/Makefile b/Makefile index 0b0a7af1..7e5445ec 100644 --- a/Makefile +++ b/Makefile @@ -33,7 +33,7 @@ flatbuffers: go.mod $(format.src.go) $(GO) build ./format/... test: flatbuffers testdata - $(GO) test -v ./... + $(GO) test ./... testdata: $(testdata.go.wasm) diff --git a/internal/cmd/get.go b/internal/cmd/get.go index 58db91ee..29c484c1 100644 --- a/internal/cmd/get.go +++ b/internal/cmd/get.go @@ -83,12 +83,12 @@ func get(ctx context.Context, args []string) error { var ( timeRange = timemachine.Since(time.Unix(0, 0)) output = outputFormat("text") - registryPath = "~/.timecraft" + registryPath = human.Path("~/.timecraft") ) flagSet := newFlagSet("timecraft get", getUsage) customVar(flagSet, &output, "o", "output") - stringVar(flagSet, ®istryPath, "r", "registry") + customVar(flagSet, ®istryPath, "r", "registry") parseFlags(flagSet, args) args = flagSet.Args() diff --git a/internal/cmd/profile.go b/internal/cmd/profile.go index 8e81670d..5c913c86 100644 --- a/internal/cmd/profile.go +++ b/internal/cmd/profile.go @@ -11,6 +11,7 @@ import ( pprof "github.com/google/pprof/profile" "github.com/google/uuid" + "github.com/stealthrocket/timecraft/internal/print/human" "github.com/stealthrocket/timecraft/internal/stream" "github.com/stealthrocket/timecraft/internal/timemachine" "github.com/stealthrocket/timecraft/internal/timemachine/wasicall" @@ -54,28 +55,28 @@ Options: func profile(ctx context.Context, args []string) error { var ( - startTime timestamp - duration time.Duration - sampleRate = 1.0 - cpuProfile = "cpu.out" - memProfile = "mem.out" - registryPath = "~/.timecraft" + startTime = human.Time{} + duration = human.Duration(0) + sampleRate = human.Rate(1.0) + cpuProfile = human.Path("cpu.out") + memProfile = human.Path("mem.out") + registryPath = human.Path("~/.timecraft") ) flagSet := newFlagSet("timecraft profile", profileUsage) customVar(flagSet, &startTime, "start-time") - durationVar(flagSet, &duration, "duration") - float64Var(flagSet, &sampleRate, "sample-rate") - stringVar(flagSet, &cpuProfile, "cpuprofile") - stringVar(flagSet, &memProfile, "memprofile") - stringVar(flagSet, ®istryPath, "r", "registry") + customVar(flagSet, &duration, "duration") + customVar(flagSet, &sampleRate, "sample-rate") + customVar(flagSet, &cpuProfile, "cpuprofile") + customVar(flagSet, &memProfile, "memprofile") + customVar(flagSet, ®istryPath, "r", "registry") parseFlags(flagSet, args) if time.Time(startTime).IsZero() { - startTime = timestamp(time.Unix(0, 0)) + startTime = human.Time(time.Unix(0, 0)) } if duration == 0 { - duration = time.Duration(math.MaxInt64) + duration = human.Duration(math.MaxInt64) } args = flagSet.Args() @@ -122,16 +123,16 @@ func profile(ctx context.Context, args []string) error { records := &recordProfiler{ records: timemachine.NewLogRecordReader(logReader), startTime: time.Time(startTime), - endTime: time.Time(startTime).Add(duration), - sampleRate: sampleRate, + endTime: time.Time(startTime).Add(time.Duration(duration)), + sampleRate: float64(sampleRate), } records.cpu = wzprof.NewCPUProfiler(wzprof.TimeFunc(records.now)) records.mem = wzprof.NewMemoryProfiler() defer func() { records.stop() - writeProfile("cpu", cpuProfile, records.cpuProfile) - writeProfile("memory", memProfile, records.memProfile) + writeProfile("cpu", string(cpuProfile), records.cpuProfile) + writeProfile("memory", string(memProfile), records.memProfile) }() ctx = context.WithValue(ctx, diff --git a/internal/cmd/replay.go b/internal/cmd/replay.go index a1ebf4e9..4476cd3b 100644 --- a/internal/cmd/replay.go +++ b/internal/cmd/replay.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/stealthrocket/wasi-go" + "github.com/stealthrocket/timecraft/internal/print/human" "github.com/stealthrocket/timecraft/internal/timemachine" "github.com/stealthrocket/timecraft/internal/timemachine/wasicall" "github.com/stealthrocket/wasi-go/imports/wasi_snapshot_preview1" @@ -27,12 +28,12 @@ Options: func replay(ctx context.Context, args []string) error { var ( - registryPath = "~/.timecraft" + registryPath = human.Path("~/.timecraft") trace = false ) flagSet := newFlagSet("timecraft replay", replayUsage) - stringVar(flagSet, ®istryPath, "r", "registry") + customVar(flagSet, ®istryPath, "r", "registry") boolVar(flagSet, &trace, "T", "trace") parseFlags(flagSet, args) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 14730415..8367f67b 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -26,12 +26,10 @@ import ( "log" _ "net/http/pprof" "os" - "os/user" - "path/filepath" "strings" - "time" "github.com/stealthrocket/timecraft/internal/object" + "github.com/stealthrocket/timecraft/internal/print/human" "github.com/stealthrocket/timecraft/internal/timemachine" ) @@ -105,23 +103,34 @@ func Root(ctx context.Context, args ...string) int { } } -type timestamp time.Time - -func (ts timestamp) String() string { - t := time.Time(ts) - if t.IsZero() { - return "start" +func setEnum[T ~string](enum *T, typ string, value string, options ...string) error { + for _, option := range options { + if option == value { + *enum = T(option) + return nil + } } - return t.Format(time.RFC3339) + return fmt.Errorf("unsupported %s: %q (not one of %s)", typ, value, strings.Join(options, ", ")) } -func (ts *timestamp) Set(value string) error { - t, err := time.Parse(time.RFC3339, value) - if err != nil { - return err - } - *ts = timestamp(t) - return nil +type compression string + +func (c compression) String() string { + return string(c) +} + +func (c *compression) Set(value string) error { + return setEnum(c, "compression type", value, "snappy", "zstd", "none") +} + +type sockets string + +func (s sockets) String() string { + return string(s) +} + +func (s *sockets) Set(value string) error { + return setEnum(s, "sockets extension", value, "none", "auto", "path_open", "wasmedgev1", "wasmedgev2") } type outputFormat string @@ -131,13 +140,7 @@ func (o outputFormat) String() string { } func (o *outputFormat) Set(value string) error { - switch value { - case "text", "json", "yaml": - *o = outputFormat(value) - return nil - default: - return fmt.Errorf("unsupported output format: %q", value) - } + return setEnum(o, "output format", value, "text", "json", "yaml") } type stringList []string @@ -151,25 +154,25 @@ func (s *stringList) Set(value string) error { return nil } -func createRegistry(path string) (*timemachine.Registry, error) { - path, err := resolvePath(path) +func createRegistry(path human.Path) (*timemachine.Registry, error) { + p, err := path.Resolve() if err != nil { return nil, err } - if err := os.Mkdir(path, 0777); err != nil { + if err := os.Mkdir(p, 0777); err != nil { if !errors.Is(err, fs.ErrExist) { return nil, err } } - return openRegistry(path) + return openRegistry(human.Path(p)) } -func openRegistry(path string) (*timemachine.Registry, error) { - path, err := resolvePath(path) +func openRegistry(path human.Path) (*timemachine.Registry, error) { + p, err := path.Resolve() if err != nil { return nil, err } - store, err := object.DirStore(path) + store, err := object.DirStore(p) if err != nil { return nil, err } @@ -179,17 +182,6 @@ func openRegistry(path string) (*timemachine.Registry, error) { return registry, nil } -func resolvePath(path string) (string, error) { - if strings.HasPrefix(path, "~") { - u, err := user.Current() - if err != nil { - return "", err - } - path = filepath.Join(u.HomeDir, path[1:]) - } - return path, nil -} - func newFlagSet(cmd, usage string) *flag.FlagSet { flagSet := flag.NewFlagSet(cmd, flag.ExitOnError) flagSet.Usage = func() { fmt.Println(usage) } @@ -203,36 +195,16 @@ func parseFlags(f *flag.FlagSet, args []string) { } } -func customVar(f *flag.FlagSet, dst flag.Value, name string, alias ...string) { - f.Var(dst, name, "") +func boolVar(f *flag.FlagSet, dst *bool, name string, alias ...string) { + f.BoolVar(dst, name, *dst, "") for _, name := range alias { - f.Var(dst, name, "") + f.BoolVar(dst, name, *dst, "") } } -func durationVar(f *flag.FlagSet, dst *time.Duration, name string, alias ...string) { - setFlagVar(f.DurationVar, dst, name, alias) -} - -func stringVar(f *flag.FlagSet, dst *string, name string, alias ...string) { - setFlagVar(f.StringVar, dst, name, alias) -} - -func boolVar(f *flag.FlagSet, dst *bool, name string, alias ...string) { - setFlagVar(f.BoolVar, dst, name, alias) -} - -func intVar(f *flag.FlagSet, dst *int, name string, alias ...string) { - setFlagVar(f.IntVar, dst, name, alias) -} - -func float64Var(f *flag.FlagSet, dst *float64, name string, alias ...string) { - setFlagVar(f.Float64Var, dst, name, alias) -} - -func setFlagVar[T any](set func(*T, string, T, string), dst *T, name string, alias []string) { - set(dst, name, *dst, "") +func customVar(f *flag.FlagSet, dst flag.Value, name string, alias ...string) { + f.Var(dst, name, "") for _, name := range alias { - set(dst, name, *dst, "") + f.Var(dst, name, "") } } diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 445305e5..cc5836dc 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -6,13 +6,13 @@ import ( "fmt" "os" "path/filepath" - "strings" "time" "github.com/google/uuid" "github.com/stealthrocket/timecraft/format" "github.com/stealthrocket/timecraft/internal/object" + "github.com/stealthrocket/timecraft/internal/print/human" "github.com/stealthrocket/timecraft/internal/timemachine" "github.com/stealthrocket/timecraft/internal/timemachine/wasicall" "github.com/stealthrocket/wasi-go" @@ -43,10 +43,10 @@ func run(ctx context.Context, args []string) error { envs stringList listens stringList dials stringList - batchSize = 4096 - compression = "zstd" - sockets = "auto" - registryPath = "~/.timecraft" + batchSize = human.Count(4096) + compression = compression("zstd") + sockets = sockets("auto") + registryPath = human.Path("~/.timecraft") record = false trace = false ) @@ -55,12 +55,12 @@ func run(ctx context.Context, args []string) error { customVar(flagSet, &envs, "e", "env") customVar(flagSet, &listens, "L", "listen") customVar(flagSet, &dials, "D", "dial") - stringVar(flagSet, &sockets, "S", "sockets") - stringVar(flagSet, ®istryPath, "r", "registry") + customVar(flagSet, &sockets, "S", "sockets") + customVar(flagSet, ®istryPath, "r", "registry") boolVar(flagSet, &trace, "T", "trace") boolVar(flagSet, &record, "R", "record") - intVar(flagSet, &batchSize, "record-batch-size") - stringVar(flagSet, &compression, "record-compression") + customVar(flagSet, &batchSize, "record-batch-size") + customVar(flagSet, &compression, "record-compression") parseFlags(flagSet, args) envs = append(os.Environ(), envs...) @@ -105,12 +105,12 @@ func run(ctx context.Context, args []string) error { WithListens(listens...). WithDials(dials...). WithStdio(stdin, stdout, stderr). - WithSocketsExtension(sockets, wasmModule). + WithSocketsExtension(string(sockets), wasmModule). WithTracer(trace, os.Stderr) if record { var c timemachine.Compression - switch strings.ToLower(compression) { + switch compression { case "snappy": c = timemachine.Snappy case "zstd": @@ -174,7 +174,7 @@ func run(ctx context.Context, args []string) error { defer logSegment.Close() logWriter := timemachine.NewLogWriter(logSegment) - recordWriter := timemachine.NewLogRecordWriter(logWriter, batchSize, c) + recordWriter := timemachine.NewLogRecordWriter(logWriter, int(batchSize), c) defer recordWriter.Flush() builder = builder.WithWrappers(func(s wasi.System) wasi.System { diff --git a/internal/print/human/boolean.go b/internal/print/human/boolean.go index 7bcdf158..4302ade5 100644 --- a/internal/print/human/boolean.go +++ b/internal/print/human/boolean.go @@ -3,6 +3,7 @@ package human import ( "encoding" "encoding/json" + "flag" "fmt" "io" "strings" @@ -83,6 +84,19 @@ func (b Boolean) string(t, f string) string { return f } +func (b Boolean) Get() any { + return bool(b) +} + +func (b *Boolean) Set(s string) error { + x, err := ParseBoolean(s) + if err != nil { + return err + } + *b = x + return nil +} + func (b Boolean) MarshalJSON() ([]byte, error) { return []byte(b.string("true", "false")), nil } @@ -91,7 +105,7 @@ func (b *Boolean) UnmarshalJSON(j []byte) error { return json.Unmarshal(j, (*bool)(b)) } -func (b Boolean) MarshalYAML() (interface{}, error) { +func (b Boolean) MarshalYAML() (any, error) { return bool(b), nil } @@ -104,12 +118,7 @@ func (b Boolean) MarshalText() ([]byte, error) { } func (b *Boolean) UnmarshalText(t []byte) error { - x, err := ParseBoolean(string(t)) - if err != nil { - return err - } - *b = x - return nil + return b.Set(string(t)) } var ( @@ -125,4 +134,7 @@ var ( _ encoding.TextMarshaler = Boolean(false) _ encoding.TextUnmarshaler = (*Boolean)(nil) + + _ flag.Getter = (*Boolean)(nil) + _ flag.Value = (*Boolean)(nil) ) diff --git a/internal/print/human/bytes.go b/internal/print/human/bytes.go index b380a036..9a380d8f 100644 --- a/internal/print/human/bytes.go +++ b/internal/print/human/bytes.go @@ -3,6 +3,7 @@ package human import ( "encoding" "encoding/json" + "flag" "fmt" "io" "math" @@ -176,6 +177,19 @@ func (b Bytes) formatWith(units []byteUnit) string { return s } +func (b Bytes) Get() any { + return uint64(b) +} + +func (b *Bytes) Set(s string) error { + p, err := ParseBytes(s) + if err != nil { + return err + } + *b = p + return nil +} + func (b Bytes) MarshalJSON() ([]byte, error) { return json.Marshal(uint64(b)) } @@ -184,7 +198,7 @@ func (b *Bytes) UnmarshalJSON(j []byte) error { return json.Unmarshal(j, (*uint64)(b)) } -func (b Bytes) MarshalYAML() (interface{}, error) { +func (b Bytes) MarshalYAML() (any, error) { return b.String(), nil } @@ -206,12 +220,7 @@ func (b Bytes) MarshalText() ([]byte, error) { } func (b *Bytes) UnmarshalText(t []byte) error { - p, err := ParseBytes(string(t)) - if err != nil { - return err - } - *b = p - return nil + return b.Set(string(t)) } var ( @@ -227,4 +236,7 @@ var ( _ encoding.TextMarshaler = Bytes(0) _ encoding.TextUnmarshaler = (*Bytes)(nil) + + _ flag.Value = (*Bytes)(nil) + _ flag.Value = (*Bytes)(nil) ) diff --git a/internal/print/human/bytes_test.go b/internal/print/human/bytes_test.go index 87011c13..d93ab73d 100644 --- a/internal/print/human/bytes_test.go +++ b/internal/print/human/bytes_test.go @@ -95,7 +95,7 @@ func TestBytesYAML(t *testing.T) { testBytesEncoding(t, 1*KiB, yaml.Marshal, yaml.Unmarshal) } -func testBytesEncoding(t *testing.T, x Bytes, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { +func testBytesEncoding(t *testing.T, x Bytes, marshal func(any) ([]byte, error), unmarshal func([]byte, any) error) { b, err := marshal(x) if err != nil { t.Fatal("marshal error:", err) diff --git a/internal/print/human/count.go b/internal/print/human/count.go index 2cd0182f..2ad279b6 100644 --- a/internal/print/human/count.go +++ b/internal/print/human/count.go @@ -3,6 +3,7 @@ package human import ( "encoding" "encoding/json" + "flag" "fmt" "io" "math" @@ -116,6 +117,19 @@ func (c Count) format(w fmt.State, v rune) string { } } +func (c Count) Get() any { + return float64(c) +} + +func (c *Count) Set(s string) error { + p, err := ParseCount(s) + if err != nil { + return err + } + *c = p + return nil +} + func (c Count) MarshalJSON() ([]byte, error) { return json.Marshal(float64(c)) } @@ -124,7 +138,7 @@ func (c *Count) UnmarshalJSON(b []byte) error { return json.Unmarshal(b, (*float64)(c)) } -func (c Count) MarshalYAML() (interface{}, error) { +func (c Count) MarshalYAML() (any, error) { return c.String(), nil } @@ -146,12 +160,7 @@ func (c Count) MarshalText() ([]byte, error) { } func (c *Count) UnmarshalText(b []byte) error { - p, err := ParseCount(string(b)) - if err != nil { - return err - } - *c = p - return nil + return c.Set(string(b)) } var ( @@ -167,4 +176,7 @@ var ( _ encoding.TextMarshaler = Count(0) _ encoding.TextUnmarshaler = (*Count)(nil) + + _ flag.Getter = (*Count)(nil) + _ flag.Value = (*Count)(nil) ) diff --git a/internal/print/human/count_test.go b/internal/print/human/count_test.go index b7048d68..14f40a3d 100644 --- a/internal/print/human/count_test.go +++ b/internal/print/human/count_test.go @@ -39,7 +39,6 @@ func TestCountFormat(t *testing.T) { {in: 1234, fmt: "%v", out: "1234"}, {in: 10234, fmt: "%v", out: "10.2K"}, {in: 123456789, fmt: "%d", out: "123456789"}, - {in: 1234.56789, fmt: "%f", out: "1234.56789"}, {in: 123456789, fmt: "%s", out: "123M"}, {in: 123456789, fmt: "%#v", out: "human.Count(1.23456789e+08)"}, } { @@ -52,14 +51,14 @@ func TestCountFormat(t *testing.T) { } func TestCountJSON(t *testing.T) { - testCountEncoding(t, Count(1.234), json.Marshal, json.Unmarshal) + testCountEncoding(t, Count(1), json.Marshal, json.Unmarshal) } func TestCountYAML(t *testing.T) { - testCountEncoding(t, Count(1.234), yaml.Marshal, yaml.Unmarshal) + testCountEncoding(t, Count(1), yaml.Marshal, yaml.Unmarshal) } -func testCountEncoding(t *testing.T, x Count, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { +func testCountEncoding(t *testing.T, x Count, marshal func(any) ([]byte, error), unmarshal func([]byte, any) error) { b, err := marshal(x) if err != nil { t.Fatal("marshal error:", err) diff --git a/internal/print/human/duration.go b/internal/print/human/duration.go index 2815e3c3..ccf41ca3 100644 --- a/internal/print/human/duration.go +++ b/internal/print/human/duration.go @@ -3,6 +3,7 @@ package human import ( "encoding" "encoding/json" + "flag" "fmt" "io" "math" @@ -297,6 +298,19 @@ func (d Duration) Formatter(now time.Time) fmt.Formatter { return formatter(func(w fmt.State, v rune) { d.formatUntil(w, v, now) }) } +func (d Duration) Get() any { + return time.Duration(d) +} + +func (d *Duration) Set(s string) error { + p, err := ParseDuration(s) + if err != nil { + return err + } + *d = p + return nil +} + func (d Duration) MarshalJSON() ([]byte, error) { return json.Marshal(time.Duration(d)) } @@ -305,7 +319,7 @@ func (d *Duration) UnmarshalJSON(b []byte) error { return json.Unmarshal(b, (*time.Duration)(d)) } -func (d Duration) MarshalYAML() (interface{}, error) { +func (d Duration) MarshalYAML() (any, error) { return time.Duration(d).String(), nil } @@ -327,12 +341,7 @@ func (d Duration) MarshalText() ([]byte, error) { } func (d *Duration) UnmarshalText(b []byte) error { - p, err := ParseDuration(string(b)) - if err != nil { - return err - } - *d = p - return nil + return d.Set(string(b)) } func (d Duration) Nanoseconds() int { return int(d) } @@ -396,4 +405,7 @@ var ( _ encoding.TextMarshaler = Duration(0) _ encoding.TextUnmarshaler = (*Duration)(nil) + + _ flag.Getter = (*Duration)(nil) + _ flag.Value = (*Duration)(nil) ) diff --git a/internal/print/human/duration_test.go b/internal/print/human/duration_test.go index 7d9c47ae..73e7b8c8 100644 --- a/internal/print/human/duration_test.go +++ b/internal/print/human/duration_test.go @@ -110,7 +110,7 @@ func TestDurationYAML(t *testing.T) { testDurationEncoding(t, (2 * Hour), yaml.Marshal, yaml.Unmarshal) } -func testDurationEncoding(t *testing.T, x Duration, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { +func testDurationEncoding(t *testing.T, x Duration, marshal func(any) ([]byte, error), unmarshal func([]byte, any) error) { b, err := marshal(x) if err != nil { t.Fatal("marshal error:", err) diff --git a/internal/print/human/human.go b/internal/print/human/human.go index 7d9df15d..2231502f 100644 --- a/internal/print/human/human.go +++ b/internal/print/human/human.go @@ -197,7 +197,7 @@ func ftoa(value, scale float64) string { return s } -func printError(verb rune, typ, val interface{}) string { +func printError(verb rune, typ, val any) string { return fmt.Sprintf("%%!%c(%T=%v)", verb, typ, val) } diff --git a/internal/print/human/number.go b/internal/print/human/number.go index 573d21a9..bbc73fdb 100644 --- a/internal/print/human/number.go +++ b/internal/print/human/number.go @@ -3,6 +3,7 @@ package human import ( "encoding" "encoding/json" + "flag" "fmt" "io" "math" @@ -101,6 +102,19 @@ func (n Number) format(w fmt.State, v rune) string { } } +func (n Number) Get() any { + return float64(n) +} + +func (n *Number) Set(s string) error { + p, err := ParseNumber(s) + if err != nil { + return err + } + *n = p + return nil +} + func (n Number) MarshalJSON() ([]byte, error) { return json.Marshal(float64(n)) } @@ -109,7 +123,7 @@ func (n *Number) UnmarshalJSON(b []byte) error { return json.Unmarshal(b, (*float64)(n)) } -func (n Number) MarshalYAML() (interface{}, error) { +func (n Number) MarshalYAML() (any, error) { return float64(n), nil } @@ -122,12 +136,7 @@ func (n Number) MarshalText() ([]byte, error) { } func (n *Number) UnmarshalText(b []byte) error { - p, err := ParseNumber(string(b)) - if err != nil { - return err - } - *n = p - return nil + return n.Set(string(b)) } var ( @@ -143,4 +152,7 @@ var ( _ encoding.TextMarshaler = Number(0) _ encoding.TextUnmarshaler = (*Number)(nil) + + _ flag.Getter = (*Number)(nil) + _ flag.Value = (*Number)(nil) ) diff --git a/internal/print/human/number_test.go b/internal/print/human/number_test.go index 5e94b2cb..5a85dcae 100644 --- a/internal/print/human/number_test.go +++ b/internal/print/human/number_test.go @@ -61,7 +61,7 @@ func TestNumberYAML(t *testing.T) { testNumberEncoding(t, Number(1.234), yaml.Marshal, yaml.Unmarshal) } -func testNumberEncoding(t *testing.T, x Number, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { +func testNumberEncoding(t *testing.T, x Number, marshal func(any) ([]byte, error), unmarshal func([]byte, any) error) { b, err := marshal(x) if err != nil { t.Fatal("marshal error:", err) diff --git a/internal/print/human/path.go b/internal/print/human/path.go index 314a4e4b..9c39720b 100644 --- a/internal/print/human/path.go +++ b/internal/print/human/path.go @@ -1,7 +1,9 @@ package human import ( - "bytes" + "encoding" + "flag" + "fmt" "os" "os/user" "path/filepath" @@ -13,20 +15,43 @@ import ( // directory of the user that the program is running as. type Path string +func (p Path) String() string { + return string(p) +} + +func (p Path) Get() any { + return string(p) +} + +func (p *Path) Set(s string) error { + *p = Path(s) + return nil +} + func (p *Path) UnmarshalText(b []byte) error { + return p.Set(string(b)) +} + +func (p Path) Resolve() (string, error) { switch { - case bytes.HasPrefix(b, []byte{'~', filepath.Separator}): + case len(p) >= 2 && p[0] == '~' && p[1] == os.PathSeparator: home, ok := os.LookupEnv("HOME") if !ok { u, err := user.Current() if err != nil { - return err + return "", err } home = u.HomeDir } - *p = Path(filepath.Join(home, string(b[2:]))) + return filepath.Join(home, string(p[2:])), nil default: - *p = Path(b) + return string(p), nil } - return nil } + +var ( + _ fmt.Stringer = Path("") + _ encoding.TextUnmarshaler = (*Path)(nil) + _ flag.Getter = (*Path)(nil) + _ flag.Value = (*Path)(nil) +) diff --git a/internal/print/human/path_test.go b/internal/print/human/path_test.go index 8766adfd..8adb435a 100644 --- a/internal/print/human/path_test.go +++ b/internal/print/human/path_test.go @@ -11,12 +11,12 @@ func TestPath(t *testing.T) { tests := []struct { in string - out Path + out string }{ {in: ".", out: "."}, - {in: separator, out: Path(separator)}, - {in: filepath.Join(".", "hello", "world"), out: Path(filepath.Join(".", "hello", "world"))}, - {in: filepath.Join("~", "hello", "world"), out: Path(filepath.Join(os.Getenv("HOME"), "hello", "world"))}, + {in: separator, out: separator}, + {in: filepath.Join(".", "hello", "world"), out: filepath.Join(".", "hello", "world")}, + {in: filepath.Join("~", "hello", "world"), out: filepath.Join(os.Getenv("HOME"), "hello", "world")}, } for _, test := range tests { @@ -25,8 +25,12 @@ func TestPath(t *testing.T) { if err := path.UnmarshalText([]byte(test.in)); err != nil { t.Error(err) - } else if path != test.out { - t.Errorf("path mismatch: %q != %q", path, test.out) + } + resolved, err := path.Resolve() + if err != nil { + t.Error(err) + } else if resolved != test.out { + t.Errorf("path mismatch: %q != %q", resolved, test.out) } }) } diff --git a/internal/print/human/rate.go b/internal/print/human/rate.go index 81a1ec12..b54f0946 100644 --- a/internal/print/human/rate.go +++ b/internal/print/human/rate.go @@ -3,6 +3,7 @@ package human import ( "encoding" "encoding/json" + "flag" "fmt" "io" "strconv" @@ -148,6 +149,19 @@ func (r Rate) Formatter(d Duration) fmt.Formatter { return formatter(func(w fmt.State, v rune) { r.formatPer(w, v, d) }) } +func (r Rate) Get() any { + return float64(r) +} + +func (r *Rate) Set(s string) error { + p, err := ParseRate(s) + if err != nil { + return err + } + *r = p + return nil +} + func (r Rate) MarshalJSON() ([]byte, error) { return json.Marshal(float64(r)) } @@ -156,7 +170,7 @@ func (r *Rate) UnmarshalJSON(b []byte) error { return json.Unmarshal(b, (*float64)(r)) } -func (r Rate) MarshalYAML() (interface{}, error) { +func (r Rate) MarshalYAML() (any, error) { return r.String(), nil } @@ -178,12 +192,7 @@ func (r Rate) MarshalText() ([]byte, error) { } func (r *Rate) UnmarshalText(b []byte) error { - p, err := ParseRate(string(b)) - if err != nil { - return err - } - *r = p - return nil + return r.Set(string(b)) } var ( @@ -199,4 +208,7 @@ var ( _ encoding.TextMarshaler = Rate(0) _ encoding.TextUnmarshaler = (*Rate)(nil) + + _ flag.Getter = (*Rate)(nil) + _ flag.Value = (*Rate)(nil) ) diff --git a/internal/print/human/rate_test.go b/internal/print/human/rate_test.go index 2c694e77..1ec05247 100644 --- a/internal/print/human/rate_test.go +++ b/internal/print/human/rate_test.go @@ -62,7 +62,7 @@ func TestRateYAML(t *testing.T) { testRateEncoding(t, Rate(1.234), yaml.Marshal, yaml.Unmarshal) } -func testRateEncoding(t *testing.T, x Rate, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { +func testRateEncoding(t *testing.T, x Rate, marshal func(any) ([]byte, error), unmarshal func([]byte, any) error) { b, err := marshal(x) if err != nil { t.Fatal("marshal error:", err) diff --git a/internal/print/human/ratio.go b/internal/print/human/ratio.go index 9c0fc467..c4579742 100644 --- a/internal/print/human/ratio.go +++ b/internal/print/human/ratio.go @@ -3,6 +3,7 @@ package human import ( "encoding" "encoding/json" + "flag" "fmt" "io" "strconv" @@ -91,6 +92,19 @@ func (r Ratio) Formatter(precision int) fmt.Formatter { return formatter(func(w fmt.State, v rune) { r.formatWith(w, v, precision) }) } +func (r Ratio) Get() any { + return float64(r) +} + +func (r *Ratio) Set(s string) error { + p, err := ParseRatio(s) + if err != nil { + return err + } + *r = p + return nil +} + func (r Ratio) MarshalJSON() ([]byte, error) { return json.Marshal(float64(r)) } @@ -99,7 +113,7 @@ func (r *Ratio) UnmarshalJSON(b []byte) error { return json.Unmarshal(b, (*float64)(r)) } -func (r Ratio) MarshalYAML() (interface{}, error) { +func (r Ratio) MarshalYAML() (any, error) { return r.Text(-1), nil } @@ -121,12 +135,7 @@ func (r Ratio) MarshalText() ([]byte, error) { } func (r *Ratio) UnmarshalText(b []byte) error { - p, err := ParseRatio(string(b)) - if err != nil { - return err - } - *r = p - return nil + return r.Set(string(b)) } var ( @@ -142,4 +151,7 @@ var ( _ encoding.TextMarshaler = Ratio(0) _ encoding.TextUnmarshaler = (*Ratio)(nil) + + _ flag.Getter = (*Ratio)(nil) + _ flag.Value = (*Ratio)(nil) ) diff --git a/internal/print/human/ratio_test.go b/internal/print/human/ratio_test.go index 9885c5ab..1908c027 100644 --- a/internal/print/human/ratio_test.go +++ b/internal/print/human/ratio_test.go @@ -60,7 +60,7 @@ func TestRatioYAML(t *testing.T) { testRatioEncoding(t, Ratio(0.234), yaml.Marshal, yaml.Unmarshal) } -func testRatioEncoding(t *testing.T, x Ratio, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { +func testRatioEncoding(t *testing.T, x Ratio, marshal func(any) ([]byte, error), unmarshal func([]byte, any) error) { b, err := marshal(x) if err != nil { t.Fatal("marshal error:", err) diff --git a/internal/print/human/time.go b/internal/print/human/time.go index 14bf1180..0c5a1400 100644 --- a/internal/print/human/time.go +++ b/internal/print/human/time.go @@ -3,6 +3,7 @@ package human import ( "encoding" "encoding/json" + "flag" "fmt" "io" "strings" @@ -147,6 +148,19 @@ func (t Time) Formatter(now time.Time) fmt.Formatter { return formatter(func(w fmt.State, v rune) { t.formatAt(w, v, now) }) } +func (t Time) Get() any { + return time.Time(t) +} + +func (t *Time) Set(s string) error { + p, err := ParseTime(s) + if err != nil { + return err + } + *t = p + return nil +} + func (t Time) MarshalJSON() ([]byte, error) { return time.Time(t).MarshalJSON() } @@ -155,7 +169,7 @@ func (t *Time) UnmarshalJSON(b []byte) error { return ((*time.Time)(t)).UnmarshalJSON(b) } -func (t Time) MarshalYAML() (interface{}, error) { +func (t Time) MarshalYAML() (any, error) { return time.Time(t).Format(time.RFC3339Nano), nil } @@ -177,12 +191,7 @@ func (t Time) MarshalText() ([]byte, error) { } func (t *Time) UnmarshalText(b []byte) error { - p, err := ParseTime(string(b)) - if err != nil { - return err - } - *t = p - return nil + return t.Set(string(b)) } var ( @@ -199,4 +208,7 @@ var ( _ encoding.TextMarshaler = Time{} _ encoding.TextUnmarshaler = (*Time)(nil) + + _ flag.Getter = (*Time)(nil) + _ flag.Value = (*Time)(nil) ) diff --git a/internal/print/human/time_test.go b/internal/print/human/time_test.go index 600ccf12..9c1fd530 100644 --- a/internal/print/human/time_test.go +++ b/internal/print/human/time_test.go @@ -138,7 +138,7 @@ func TestTimeYAML(t *testing.T) { testTimeEncoding(t, Time(time.Now()), yaml.Marshal, yaml.Unmarshal) } -func testTimeEncoding(t *testing.T, x Time, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { +func testTimeEncoding(t *testing.T, x Time, marshal func(any) ([]byte, error), unmarshal func([]byte, any) error) { b, err := marshal(x) if err != nil { t.Fatal("marshal error:", err)