diff --git a/Makefile b/Makefile index 7838e556..0b0a7af1 100644 --- a/Makefile +++ b/Makefile @@ -14,8 +14,9 @@ format.src.go = \ timecraft.src.go = \ $(format.src.go) \ $(wildcard *.go) \ - $(wildcard cmd/*.go) \ - $(wildcard internal/*/*.go) + $(wildcard */*.go) \ + $(wildcard */*/*.go) \ + $(wildcard */*/*/*.go) timecraft: go.mod $(timecraft.src.go) $(GO) build -o timecraft @@ -23,6 +24,9 @@ timecraft: go.mod $(timecraft.src.go) clean: rm -f timecraft $(format.src.go) $(testdata.go.wasm) +lint: + golangci-lint run ./... + generate: flatbuffers flatbuffers: go.mod $(format.src.go) diff --git a/format/timecraft.go b/format/timecraft.go index 864371f9..06b9d322 100644 --- a/format/timecraft.go +++ b/format/timecraft.go @@ -27,6 +27,15 @@ func SHA256(b []byte) Hash { } } +func ParseHash(s string) (h Hash, err error) { + var ok bool + h.Algorithm, h.Digest, ok = strings.Cut(s, ":") + if !ok { + err = fmt.Errorf("malformed hash: %s", s) + } + return h, err +} + func (h Hash) String() string { return h.Algorithm + ":" + h.Digest } @@ -60,6 +69,8 @@ const ( TypeTimecraftModule MediaType = "application/vnd.timecraft.module.v1+wasm" ) +func (m MediaType) String() string { return string(m) } + type Resource interface { ContentType() MediaType } @@ -75,11 +86,11 @@ type ResourceUnmarshaler interface { } type Descriptor struct { - MediaType MediaType `json:"mediaType"` - Digest Hash `json:"digest"` - Size int64 `json:"size"` - URLs []string `json:"urls,omitempty"` - Annotations map[string]string `json:"annotations,omitempty"` + MediaType MediaType `json:"mediaType" yaml:"mediaType"` + Digest Hash `json:"digest" yaml:"digest"` + Size int64 `json:"size" yaml:"size"` + URLs []string `json:"urls,omitempty" yaml:"urls,omitempty"` + Annotations map[string]string `json:"annotations,omitempty" yaml:"annotations,omitempty"` } func (d *Descriptor) ContentType() MediaType { @@ -112,7 +123,7 @@ func (m *Module) UnmarshalResource(b []byte) error { } type Runtime struct { - Version string `json:"version"` + Version string `json:"version" yaml:"version"` } func (r *Runtime) ContentType() MediaType { @@ -128,10 +139,10 @@ func (r *Runtime) UnmarshalResource(b []byte) error { } type Config struct { - Runtime *Descriptor `json:"runtime"` - Modules []*Descriptor `json:"modules"` - Args []string `json:"args"` - Env []string `json:"env,omitempty"` + Runtime *Descriptor `json:"runtime" yaml:"runtime"` + Modules []*Descriptor `json:"modules" yaml:"modules"` + Args []string `json:"args" yaml:"args"` + Env []string `json:"env,omitempty" yaml:"env,omitempty"` } func (c *Config) ContentType() MediaType { @@ -147,9 +158,9 @@ func (c *Config) UnmarshalResource(b []byte) error { } type Process struct { - ID UUID `json:"id"` - StartTime time.Time `json:"startTime"` - Config *Descriptor `json:"config"` + ID UUID `json:"id" yaml:"id"` + StartTime time.Time `json:"startTime" yaml:"startTime"` + Config *Descriptor `json:"config" yaml:"config"` } func (p *Process) ContentType() MediaType { @@ -178,8 +189,8 @@ func jsonDecode(b []byte, value any) error { } type Manifest struct { - Process *Descriptor `json:"process"` - StartTime time.Time `json:"startTime"` + Process *Descriptor `json:"process" yaml:"process"` + StartTime time.Time `json:"startTime" yaml:"startTime"` } func (m *Manifest) ContentType() MediaType { diff --git a/go.mod b/go.mod index 556abc42..78657b8a 100644 --- a/go.mod +++ b/go.mod @@ -15,4 +15,7 @@ require ( golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 ) -require golang.org/x/sys v0.8.0 // indirect +require ( + golang.org/x/sys v0.8.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 505c1906..d07c66dc 100644 --- a/go.sum +++ b/go.sum @@ -12,8 +12,6 @@ github.com/stealthrocket/wasi-go v0.1.1 h1:9Q9zpKWItoObGjNG5kkllzHx1sksiq/MKfuYd github.com/stealthrocket/wasi-go v0.1.1/go.mod h1:LBhZHvAroNNQTejkVTMJZ01ssj3jXF+3Lkbru4cTzGQ= github.com/stealthrocket/wazergo v0.19.0 h1:0ZBya2fBURvV+I2hGl0vcuQ8dgoUvllxQ7aYlZSA5nI= github.com/stealthrocket/wazergo v0.19.0/go.mod h1:riI0hxw4ndZA5e6z7PesHg2BtTftcZaMxRcoiGGipTs= -github.com/stealthrocket/wzprof v0.1.4 h1:Yb/JHAQIpzCrpr/Nw/rgZxqqTigW2HT8SKNs6SLGFV4= -github.com/stealthrocket/wzprof v0.1.4/go.mod h1:lUNsjcNEjviNBV8+MhOGGNBI/SQa7miJQaoXOTgRRok= github.com/stealthrocket/wzprof v0.1.5-0.20230526193557-ec6e2ad60848 h1:gNZnxEbv7OgKkGvRU4PtGDZpzls81FV0IYoUW3I46l0= github.com/stealthrocket/wzprof v0.1.5-0.20230526193557-ec6e2ad60848/go.mod h1:hqLzj5iDSncc6rlPMhC51O642AkaC+dWVPNNalZdlCY= github.com/tetratelabs/wazero v1.1.1-0.20230522055633-256b7a4bf970 h1:X5OOeHRjoLA8XhVc7biEbh1/hnTzpYpPn7HuyarMslQ= @@ -22,3 +20,6 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cmd/get.go b/internal/cmd/get.go new file mode 100644 index 00000000..58db91ee --- /dev/null +++ b/internal/cmd/get.go @@ -0,0 +1,274 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "strings" + "time" + + "github.com/stealthrocket/timecraft/format" + "github.com/stealthrocket/timecraft/internal/print/human" + "github.com/stealthrocket/timecraft/internal/print/jsonprint" + "github.com/stealthrocket/timecraft/internal/print/textprint" + "github.com/stealthrocket/timecraft/internal/print/yamlprint" + "github.com/stealthrocket/timecraft/internal/stream" + "github.com/stealthrocket/timecraft/internal/timemachine" +) + +const getUsage = ` +Usage: timecraft get [options] + + The get sub-command gives access to the state of the time machine registry. + The command must be followed by the name of resources to display, which must + be one of config, module, process, or runtime. + (the command also accepts plurals and abbreviations of the resource names) + +Examples: + + $ timecraft get modules + MODULE ID MODULE NAME SIZE + 9d7b7563baf3 app.wasm 6.82 MiB + + $ timecraft get modules -o json + { + "mediaType": "application/vnd.timecraft.module.v1+wasm", + "digest": "sha256:9d7b7563baf3702cf24ed3688dc9a58faef2d0ac586041cb2dc95df919f5e5f2", + "size": 7150231, + "annotations": { + "timecraft.module.name": "app.wasm", + "timecraft.object.created-at": "2023-05-28T21:52:26Z", + "timecraft.object.resource-type": "module" + } + } + +Options: + -h, --help Show this usage information + -o, --ouptut format Output format, one of: text, json, yaml + -r, --registry path Path to the timecraft registry (default to ~/.timecraft) +` + +type resource struct { + name string + alt []string + get func(context.Context, io.Writer, *timemachine.Registry) stream.WriteCloser[*format.Descriptor] +} + +var resources = [...]resource{ + { + name: "config", + alt: []string{"conf", "configs"}, + get: getConfigs, + }, + { + name: "module", + alt: []string{"mo", "mod", "mods", "modules"}, + get: getModules, + }, + { + name: "process", + alt: []string{"ps", "procs", "processes"}, + get: getProcesses, + }, + { + name: "runtime", + alt: []string{"rt", "runtimes"}, + get: getRuntimes, + }, +} + +func get(ctx context.Context, args []string) error { + var ( + timeRange = timemachine.Since(time.Unix(0, 0)) + output = outputFormat("text") + registryPath = "~/.timecraft" + ) + + flagSet := newFlagSet("timecraft get", getUsage) + customVar(flagSet, &output, "o", "output") + stringVar(flagSet, ®istryPath, "r", "registry") + parseFlags(flagSet, args) + + args = flagSet.Args() + if len(args) == 0 { + return errors.New(`expected exactly one resource name as argument` + useGet()) + } + resourceNamePrefix := args[0] + parseFlags(flagSet, args[1:]) + + resource, ok := findResource(resourceNamePrefix, resources[:]) + if !ok { + matchingResources := findMatchingResources(resourceNamePrefix, resources[:]) + if len(matchingResources) == 0 { + return fmt.Errorf(`no resources matching '%s'`+useGet(), resourceNamePrefix) + } + return fmt.Errorf(`no resources matching '%s' + +Did you mean?%s`, resourceNamePrefix, joinResourceNames(matchingResources, "\n ")) + } + + registry, err := openRegistry(registryPath) + if err != nil { + return err + } + + reader := registry.ListResources(ctx, resource.name, timeRange) + defer reader.Close() + + var writer stream.WriteCloser[*format.Descriptor] + switch output { + case "json": + writer = jsonprint.NewWriter[*format.Descriptor](os.Stdout) + case "yaml": + writer = yamlprint.NewWriter[*format.Descriptor](os.Stdout) + default: + writer = resource.get(ctx, os.Stdout, registry) + } + defer writer.Close() + + _, err = stream.Copy[*format.Descriptor](writer, reader) + return err +} + +func getConfigs(ctx context.Context, w io.Writer, r *timemachine.Registry) stream.WriteCloser[*format.Descriptor] { + type config struct { + ID string `text:"CONFIG ID"` + Runtime string `text:"RUNTIME"` + Modules int `text:"MODULES"` + Size human.Bytes `text:"SIZE"` + } + return newDescTableWriter(w, func(desc *format.Descriptor) (config, error) { + c, err := r.LookupConfig(ctx, desc.Digest) + if err != nil { + return config{}, err + } + r, err := r.LookupRuntime(ctx, c.Runtime.Digest) + if err != nil { + return config{}, err + } + return config{ + ID: desc.Digest.Digest[:12], + Runtime: r.Version, + Modules: len(c.Modules), + Size: human.Bytes(desc.Size), + }, nil + }) +} + +func getModules(ctx context.Context, w io.Writer, r *timemachine.Registry) stream.WriteCloser[*format.Descriptor] { + type module struct { + ID string `text:"MODULE ID"` + Name string `text:"MODULE NAME"` + Size human.Bytes `text:"SIZE"` + } + return newDescTableWriter(w, func(desc *format.Descriptor) (module, error) { + name := desc.Annotations["timecraft.module.name"] + if name == "" { + name = "(none)" + } + return module{ + ID: desc.Digest.Digest[:12], + Name: name, + Size: human.Bytes(desc.Size), + }, nil + }) +} + +func getProcesses(ctx context.Context, w io.Writer, r *timemachine.Registry) stream.WriteCloser[*format.Descriptor] { + type process struct { + ID format.UUID `text:"PROCESS ID"` + StartTime human.Time `text:"STARTED"` + } + return newDescTableWriter(w, func(desc *format.Descriptor) (process, error) { + p, err := r.LookupProcess(ctx, desc.Digest) + if err != nil { + return process{}, err + } + return process{ + ID: p.ID, + StartTime: human.Time(p.StartTime), + }, nil + }) +} + +func getRuntimes(ctx context.Context, w io.Writer, r *timemachine.Registry) stream.WriteCloser[*format.Descriptor] { + type runtime struct { + ID string `text:"RUNTIME ID"` + Version string `text:"VERSION"` + CreatedAt human.Time `text:"CREATED"` + } + return newDescTableWriter(w, func(desc *format.Descriptor) (runtime, error) { + r, err := r.LookupRuntime(ctx, desc.Digest) + if err != nil { + return runtime{}, err + } + t, err := human.ParseTime(desc.Annotations["timecraft.object.created-at"]) + if err != nil { + return runtime{}, err + } + return runtime{ + ID: desc.Digest.Digest[:12], + Version: r.Version, + CreatedAt: t, + }, nil + }) +} + +func newDescTableWriter[T any](w io.Writer, conv func(*format.Descriptor) (T, error)) stream.WriteCloser[*format.Descriptor] { + tw := textprint.NewTableWriter[T](w) + cw := stream.ConvertWriter[T](tw, conv) + return stream.NewWriteCloser(cw, tw) +} + +func findResource(name string, options []resource) (resource, bool) { + for _, option := range options { + if option.name == name { + return option, true + } + for _, alt := range option.alt { + if alt == name { + return option, true + } + } + } + return resource{}, false +} + +func findMatchingResources(name string, options []resource) (matches []resource) { + for _, option := range options { + if prefixLength(option.name, name) > 1 || prefixLength(name, option.name) > 1 { + matches = append(matches, option) + } + } + return matches +} + +func prefixLength(base, prefix string) int { + n := 0 + for n < len(base) && n < len(prefix) && base[n] == prefix[n] { + n++ + } + return n +} + +func joinResourceNames(resources []resource, prefix string) string { + s := new(strings.Builder) + for _, r := range resources { + s.WriteString(prefix) + s.WriteString(r.name) + } + return s.String() +} + +func useGet() string { + s := new(strings.Builder) + s.WriteString("\n\n") + s.WriteString(`Use 'timecraft ' where the supported resource names are:`) + for _, r := range resources { + s.WriteString("\n ") + s.WriteString(r.name) + } + return s.String() +} diff --git a/internal/cmd/help.go b/internal/cmd/help.go index e29b613d..758a9de3 100644 --- a/internal/cmd/help.go +++ b/internal/cmd/help.go @@ -8,6 +8,9 @@ import ( const helpUsage = ` Usage: timecraft [options] +Registry Commands: + get Display resources from the time machine registry + Runtime Commands: run Run a WebAssembly module, and optionally trace execution replay Replay a recorded trace of execution @@ -33,6 +36,8 @@ func help(ctx context.Context, args []string) error { } switch cmd { + case "get": + msg = getUsage case "help", "": msg = helpUsage case "profile": diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 4804493c..14730415 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -78,6 +78,8 @@ func Root(ctx context.Context, args ...string) int { var err error cmd, args := args[0], args[1:] switch cmd { + case "get": + err = get(ctx, args) case "help": err = help(ctx, args) case "profile": @@ -122,6 +124,22 @@ func (ts *timestamp) Set(value string) error { return nil } +type outputFormat string + +func (o outputFormat) String() string { + return string(o) +} + +func (o *outputFormat) Set(value string) error { + switch value { + case "text", "json", "yaml": + *o = outputFormat(value) + return nil + default: + return fmt.Errorf("unsupported output format: %q", value) + } +} + type stringList []string func (s stringList) String() string { @@ -151,11 +169,14 @@ func openRegistry(path string) (*timemachine.Registry, error) { if err != nil { return nil, err } - dir, err := object.DirStore(path) + store, err := object.DirStore(path) if err != nil { return nil, err } - return timemachine.NewRegistry(dir), nil + registry := &timemachine.Registry{ + Store: store, + } + return registry, nil } func resolvePath(path string) (string, error) { diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 4dc87742..445305e5 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "github.com/stealthrocket/timecraft/format" + "github.com/stealthrocket/timecraft/internal/object" "github.com/stealthrocket/timecraft/internal/timemachine" "github.com/stealthrocket/timecraft/internal/timemachine/wasicall" "github.com/stealthrocket/wasi-go" @@ -125,6 +126,9 @@ func run(ctx context.Context, args []string) error { module, err := registry.CreateModule(ctx, &format.Module{ Code: wasmCode, + }, object.Tag{ + Name: "timecraft.module.name", + Value: wasmModule.Name(), }) if err != nil { return err diff --git a/internal/object/query/query.go b/internal/object/query/query.go new file mode 100644 index 00000000..579b1cbc --- /dev/null +++ b/internal/object/query/query.go @@ -0,0 +1,67 @@ +package query + +import "time" + +type Value interface { + After(time.Time) bool + Before(time.Time) bool + Match(name, value string) bool +} + +type Filter[T Value] interface { + Match(T) bool +} + +type After[T Value] time.Time + +func (op After[T]) Match(value T) bool { + return value.After(time.Time(op)) +} + +type Before[T Value] time.Time + +func (op Before[T]) Match(value T) bool { + return value.Before(time.Time(op)) +} + +type Match[T Value] [2]string + +func (op Match[T]) Match(value T) bool { + return value.Match(op[0], op[1]) +} + +type And[T Value] [2]Filter[T] + +func (op And[T]) Match(value T) bool { + return op[0].Match(value) && op[1].Match(value) +} + +type Or[T Value] [2]Filter[T] + +func (op Or[T]) Match(value T) bool { + return op[0].Match(value) || op[1].Match(value) +} + +type Not[T Value] [1]Filter[T] + +func (op Not[T]) Match(value T) bool { + return !op[0].Match(value) +} + +func MatchAll[T Value](value T, filters ...Filter[T]) bool { + for _, filter := range filters { + if !filter.Match(value) { + return false + } + } + return true +} + +func MatchOne[T Value](value T, filters ...Filter[T]) bool { + for _, filter := range filters { + if filter.Match(value) { + return true + } + } + return false +} diff --git a/internal/object/store.go b/internal/object/store.go index 267bc99f..26862204 100644 --- a/internal/object/store.go +++ b/internal/object/store.go @@ -1,22 +1,7 @@ package object -// Goals: -// - can store webassembly module byte code (no duplicates) -// - can store log segments -// - can find logs of processes by id -// - can be implemented by an object store (e.g. S3) -// - can store snapshots of log segments -// -// Opportunities: -// - could compact runtime configuration (no duplicates) -// -// Questions: -// - how to represent runs which are mutable? -// * hash of manifest/metadata? -// * one layer per segment or one file per segment? -// - should we rely more on the storage layer to model the log? - import ( + "bytes" "context" "errors" "fmt" @@ -28,7 +13,9 @@ import ( "strings" "time" + "github.com/stealthrocket/timecraft/internal/object/query" "github.com/stealthrocket/timecraft/internal/stream" + "golang.org/x/exp/slices" ) var ( @@ -42,6 +29,56 @@ var ( ErrReadOnly = errors.New("read only object store") ) +// Tag represents a name/value pair attached to an object. +type Tag struct { + Name string + Value string +} + +func AppendTags(buf []byte, tags ...Tag) []byte { + for _, tag := range tags { + buf = append(buf, tag.Name...) + buf = append(buf, '=') + buf = append(buf, tag.Value...) + buf = append(buf, '\n') + } + return buf +} + +func (t Tag) String() string { + return t.Name + "=" + t.Value +} + +// Filter represents a predicate applicated to objects to determine whether they +// are part of the result of a ListObject operation. +type Filter = query.Filter[*Info] + +func AFTER(t time.Time) Filter { return query.After[*Info](t) } + +func BEFORE(t time.Time) Filter { return query.Before[*Info](t) } + +func MATCH(name, value string) Filter { return query.Match[*Info]{name, value} } + +func AND(f1, f2 Filter) Filter { return query.And[*Info]{f1, f2} } + +func OR(f1, f2 Filter) Filter { return query.Or[*Info]{f1, f2} } + +func NOT(f Filter) Filter { return query.Not[*Info]{f} } + +func validTag(tag Tag) bool { + return validTagName(tag.Name) && validTagValue(tag.Value) +} + +func validTagName(name string) bool { + return name != "" && + strings.IndexByte(name, '=') < 0 && + strings.IndexByte(name, '\n') < 0 +} + +func validTagValue(value string) bool { + return strings.IndexByte(value, '\n') < 0 +} + // Store is an interface abstracting an object storage layer. // // Once created, objects are immutable, the store does not need to offer a @@ -54,7 +91,7 @@ type Store interface { // // The creation of objects is atomic, the store must be left unchanged if // an error occurs that would cause the object to be only partially created. - CreateObject(ctx context.Context, name string, data io.Reader) error + CreateObject(ctx context.Context, name string, data io.Reader, tags ...Tag) error // Reads an existing object from the store, returning a reader exposing its // content. @@ -67,7 +104,7 @@ type Store interface { // // Objects that are being created by a call to CreateObject are not visible // until the creation completed. - ListObjects(ctx context.Context, prefix string) stream.ReadCloser[Info] + ListObjects(ctx context.Context, prefix string, filters ...Filter) stream.ReadCloser[Info] // Deletes and object from the store. // @@ -82,6 +119,33 @@ type Info struct { Name string Size int64 CreatedAt time.Time + Tags []Tag +} + +func (info *Info) After(t time.Time) bool { + return info.CreatedAt.After(t) +} + +func (info *Info) Before(t time.Time) bool { + return info.CreatedAt.Before(t) +} + +func (info *Info) Match(name, value string) bool { + for _, tag := range info.Tags { + if tag.Name == name && tag.Value == value { + return true + } + } + return false +} + +func (info *Info) Lookup(name string) (string, bool) { + for _, tag := range info.Tags { + if tag.Name == name { + return tag.Value, true + } + } + return "", false } // EmptyStore returns a Store instance representing an empty, read-only object @@ -90,23 +154,23 @@ func EmptyStore() Store { return emptyStore{} } type emptyStore struct{} -func (emptyStore) CreateObject(ctx context.Context, name string, data io.Reader) error { +func (emptyStore) CreateObject(context.Context, string, io.Reader, ...Tag) error { return ErrReadOnly } -func (emptyStore) ReadObject(ctx context.Context, name string) (io.ReadCloser, error) { +func (emptyStore) ReadObject(context.Context, string) (io.ReadCloser, error) { return nil, ErrNotExist } -func (emptyStore) StatObject(ctx context.Context, name string) (Info, error) { +func (emptyStore) StatObject(context.Context, string) (Info, error) { return Info{}, ErrNotExist } -func (emptyStore) ListObjects(ctx context.Context, prefix string) stream.ReadCloser[Info] { +func (emptyStore) ListObjects(context.Context, string, ...Filter) stream.ReadCloser[Info] { return emptyInfoReader{} } -func (emptyStore) DeleteObject(ctx context.Context, name string) error { +func (emptyStore) DeleteObject(context.Context, string) error { return nil } @@ -126,32 +190,69 @@ func DirStore(path string) (Store, error) { type dirStore string -func (store dirStore) CreateObject(ctx context.Context, name string, data io.Reader) error { +func (store dirStore) CreateObject(ctx context.Context, name string, data io.Reader, tags ...Tag) error { filePath, err := store.joinPath(name) if err != nil { return err } dirPath, fileName := filepath.Split(filePath) + if strings.HasPrefix(fileName, ".") { + return fmt.Errorf("object names cannot start with a dot: %s", name) + } + if err := os.MkdirAll(dirPath, 0777); err != nil { return err } - file, err := os.CreateTemp(dirPath, "."+fileName+".*") + var tagsPath string + var tagsData []byte + if len(tags) > 0 { + for _, tag := range tags { + if !validTag(tag) { + return fmt.Errorf("invalid tag: %q=%q", tag.Name, tag.Value) + } + } + + tagsPath = filepath.Join(dirPath, ".tags", fileName) + tagsData = make([]byte, 0, 256) + tagsData = AppendTags(tagsData, tags...) + + if err := os.Mkdir(filepath.Join(dirPath, ".tags"), 0777); err != nil { + if !errors.Is(err, fs.ErrExist) { + return err + } + } + if err := os.WriteFile(tagsPath, tagsData, 0666); err != nil { + return err + } + } + + objectFile, err := os.CreateTemp(dirPath, "."+fileName+".*") if err != nil { return err } - defer file.Close() - tmpPath := file.Name() + defer objectFile.Close() + + tmpPath := objectFile.Name() + success := false + defer func() { + if !success { + os.Remove(tmpPath) + if tagsPath != "" { + os.Remove(tagsPath) + } + } + }() - if _, err := io.Copy(file, data); err != nil { - os.Remove(tmpPath) + if _, err := io.Copy(objectFile, data); err != nil { return err } if err := os.Rename(tmpPath, filePath); err != nil { - os.Remove(tmpPath) return err } + + success = true return nil } @@ -172,7 +273,15 @@ func (store dirStore) StatObject(ctx context.Context, name string) (Info, error) if err != nil { return Info{}, err } - stat, err := os.Stat(path) + stat, err := os.Lstat(path) + if err != nil { + return Info{}, err + } + if !stat.Mode().IsRegular() { + return Info{}, ErrNotExist + } + dir, base := filepath.Split(path) + tags, err := readTags(filepath.Join(dir, ".tags", base)) if err != nil { return Info{}, err } @@ -180,11 +289,12 @@ func (store dirStore) StatObject(ctx context.Context, name string) (Info, error) Name: name, Size: stat.Size(), CreatedAt: stat.ModTime(), + Tags: tags, } return info, nil } -func (store dirStore) ListObjects(ctx context.Context, prefix string) stream.ReadCloser[Info] { +func (store dirStore) ListObjects(ctx context.Context, prefix string, filters ...Filter) stream.ReadCloser[Info] { if prefix != "." && !strings.HasSuffix(prefix, "/") { prefix += "/" } @@ -200,7 +310,12 @@ func (store dirStore) ListObjects(ctx context.Context, prefix string) stream.Rea return &errorInfoReader{err: err} } } - return &dirReader{dir: dir, path: filepath.ToSlash(prefix)} + return &dirReader{ + dir: dir, + path: path, + prefix: prefix, + filters: slices.Clone(filters), + } } func (store dirStore) DeleteObject(ctx context.Context, name string) error { @@ -218,15 +333,41 @@ func (store dirStore) DeleteObject(ctx context.Context, name string) error { } func (store dirStore) joinPath(name string) (string, error) { - if !fs.ValidPath(name) { + if name = path.Clean(name); !fs.ValidPath(name) { return "", fmt.Errorf("invalid object name: %q", name) } return filepath.Join(string(store), filepath.FromSlash(name)), nil } +func readTags(path string) ([]Tag, error) { + b, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + err = nil + } + return nil, err + } + tags := make([]Tag, 0, 8) + nl := []byte{'\n'} + eq := []byte{'='} + for _, line := range bytes.Split(b, nl) { + name, value, ok := bytes.Cut(line, eq) + if !ok { + continue + } + tags = append(tags, Tag{ + Name: string(name), + Value: string(value), + }) + } + return tags, nil +} + type dirReader struct { - dir *os.File - path string + dir *os.File + path string + prefix string + filters []Filter } func (r *dirReader) Close() error { @@ -239,10 +380,6 @@ func (r *dirReader) Read(items []Info) (n int, err error) { dirents, err := r.dir.ReadDir(len(items) - n) for _, dirent := range dirents { - if dirent.IsDir() { - continue - } - name := dirent.Name() if strings.HasPrefix(name, ".") { continue @@ -250,16 +387,25 @@ func (r *dirReader) Read(items []Info) (n int, err error) { info, err := dirent.Info() if err != nil { - r.dir.Close() + return n, err + } + + tagsPath := filepath.Join(r.path, ".tags", name) + tags, err := readTags(tagsPath) + if err != nil { return n, err } items[n] = Info{ - Name: path.Join(r.path, name), + Name: path.Join(r.prefix, name), Size: info.Size(), CreatedAt: info.ModTime(), + Tags: tags, + } + + if query.MatchAll(&items[n], r.filters...) { + n++ } - n++ } if err != nil { diff --git a/internal/object/store_test.go b/internal/object/store_test.go index 6ea1be6d..97da178f 100644 --- a/internal/object/store_test.go +++ b/internal/object/store_test.go @@ -69,6 +69,11 @@ func testObjectStore(t *testing.T, newStore func(*testing.T) (object.Store, func scenario: "objects being created are not visible when listing", function: testObjectStoreListWhileCreate, }, + + { + scenario: "tagged objects are filtered when listing", + function: testObjectStoreListTaggedObjects, + }, } for _, test := range tests { @@ -115,11 +120,8 @@ func testObjectStoreCreateAndList(t *testing.T, ctx context.Context, store objec assert.OK(t, store.CreateObject(ctx, "test-2", strings.NewReader("A"))) assert.OK(t, store.CreateObject(ctx, "test-3", strings.NewReader("BC"))) - objects := readValues(t, store.ListObjects(ctx, ".")) - clearCreatedAt(objects) - sortObjectInfo(objects) - - assert.EqualAll(t, objects, []object.Info{ + objects := listObjects(t, ctx, store, ".") + assert.DeepEqual(t, objects, []object.Info{ {Name: "test-1", Size: 0}, {Name: "test-2", Size: 1}, {Name: "test-3", Size: 2}, @@ -150,11 +152,8 @@ func testObjectStoreDeleteAndList(t *testing.T, ctx context.Context, store objec assert.OK(t, store.CreateObject(ctx, "test-3", strings.NewReader("BC"))) assert.OK(t, store.DeleteObject(ctx, "test-2")) - objects := readValues(t, store.ListObjects(ctx, ".")) - clearCreatedAt(objects) - sortObjectInfo(objects) - - assert.EqualAll(t, objects, []object.Info{ + objects := listObjects(t, ctx, store, ".") + assert.DeepEqual(t, objects, []object.Info{ {Name: "test-1", Size: 0}, {Name: "test-3", Size: 2}, }) @@ -175,11 +174,8 @@ func testObjectStoreListWhileCreate(t *testing.T, ctx context.Context, store obj _, err := io.WriteString(w, "H") assert.OK(t, err) - beforeCreateObject := readValues(t, store.ListObjects(ctx, ".")) - clearCreatedAt(beforeCreateObject) - sortObjectInfo(beforeCreateObject) - - assert.EqualAll(t, beforeCreateObject, []object.Info{ + beforeCreateObject := listObjects(t, ctx, store, ".") + assert.DeepEqual(t, beforeCreateObject, []object.Info{ {Name: "test-1", Size: 0}, {Name: "test-2", Size: 1}, }) @@ -189,17 +185,76 @@ func testObjectStoreListWhileCreate(t *testing.T, ctx context.Context, store obj assert.OK(t, w.Close()) <-done - afterCreateObject := readValues(t, store.ListObjects(ctx, ".")) - clearCreatedAt(afterCreateObject) - sortObjectInfo(afterCreateObject) - - assert.EqualAll(t, afterCreateObject, []object.Info{ + afterCreateObject := listObjects(t, ctx, store, ".") + assert.DeepEqual(t, afterCreateObject, []object.Info{ {Name: "test-1", Size: 0}, {Name: "test-2", Size: 1}, {Name: "test-3", Size: 12}, }) } +func testObjectStoreListTaggedObjects(t *testing.T, ctx context.Context, store object.Store) { + assert.OK(t, store.CreateObject(ctx, "test-1", strings.NewReader(""))) // no tags + assert.OK(t, store.CreateObject(ctx, "test-2", strings.NewReader("A"), + object.Tag{"tag-1", "value-1"}, + object.Tag{"tag-2", "value-2"}, + )) + assert.OK(t, store.CreateObject(ctx, "test-3", strings.NewReader("BC"), + object.Tag{"tag-1", "value-1"}, + object.Tag{"tag-1", "value-2"}, + object.Tag{"tag-2", "value-3"}, + )) + + object1 := object.Info{Name: "test-1", Size: 0} + object2 := object.Info{Name: "test-2", Size: 1, Tags: []object.Tag{{"tag-1", "value-1"}, {"tag-2", "value-2"}}} + object3 := object.Info{Name: "test-3", Size: 2, Tags: []object.Tag{{"tag-1", "value-1"}, {"tag-1", "value-2"}, {"tag-2", "value-3"}}} + + assert.DeepEqual(t, + listObjects(t, ctx, store, "."), + []object.Info{object1, object2, object3}) + + assert.DeepEqual(t, + listObjects(t, ctx, store, ".", object.MATCH("tag-1", "value-1")), + []object.Info{object2, object3}) + + assert.DeepEqual(t, + listObjects(t, ctx, store, ".", object.MATCH("tag-1", "value-2")), + []object.Info{object3}) + + assert.DeepEqual(t, + listObjects(t, ctx, store, ".", object.MATCH("tag-2", "value-2")), + []object.Info{object2}) + + assert.DeepEqual(t, + listObjects(t, ctx, store, ".", object.MATCH("tag-2", "value-3")), + []object.Info{object3}) + + assert.DeepEqual(t, + listObjects(t, ctx, store, ".", + object.OR( + object.MATCH("tag-2", "value-2"), + object.MATCH("tag-2", "value-3"), + ), + ), + []object.Info{object2, object3}) + + assert.DeepEqual(t, + listObjects(t, ctx, store, ".", + object.AND( + object.MATCH("tag-1", "value-2"), + object.MATCH("tag-2", "value-3"), + ), + ), + []object.Info{object3}) +} + +func listObjects(t *testing.T, ctx context.Context, store object.Store, prefix string, filters ...object.Filter) []object.Info { + objects := readValues(t, store.ListObjects(ctx, ".", filters...)) + clearCreatedAt(objects) + sortObjectInfo(objects) + return objects +} + func readBytes(t *testing.T, r io.ReadCloser) []byte { t.Helper() defer r.Close() diff --git a/internal/print/human/LICENSE b/internal/print/human/LICENSE new file mode 100644 index 00000000..9c3721a6 --- /dev/null +++ b/internal/print/human/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Segment + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/internal/print/human/boolean.go b/internal/print/human/boolean.go new file mode 100644 index 00000000..7bcdf158 --- /dev/null +++ b/internal/print/human/boolean.go @@ -0,0 +1,128 @@ +package human + +import ( + "encoding" + "encoding/json" + "fmt" + "io" + "strings" + + yaml "gopkg.in/yaml.v3" +) + +// Boolean returns a boolean value. +// +// The type supports parsing values as "true", "false", "yes", or "no", all +// case insensitive. +type Boolean bool + +func ParseBoolean(s string) (Boolean, error) { + switch strings.ToLower(s) { + case "true", "yes": + return true, nil + case "false", "no": + return false, nil + default: + return false, fmt.Errorf("invalid boolean representation: %q", s) + } +} + +// String satisfies the fmt.Stringer interface, returns "yes" or "no". +func (b Boolean) String() string { return b.string("yes", "no") } + +// GoString satisfies the fmt.GoStringer interface. +func (b Boolean) GoString() string { + return fmt.Sprintf("human.Boolean(%t)", bool(b)) +} + +// Format satisfies the fmt.Formatter interface. +// +// The method supports the following formatting verbse: +// +// s "yes" or "no" +// t "true" or "false" +// v same as 's' +// +// For each of these options, these extra flags are also intepreted: +// +// Capitalized + +// All uppercase # +func (b Boolean) Format(w fmt.State, v rune) { + _, _ = io.WriteString(w, b.format(w, v)) +} + +func (b Boolean) format(w fmt.State, v rune) string { + switch v { + case 's', 'v': + switch { + case w.Flag('#'): + return b.string("YES", "NO") + case w.Flag('+'): + return b.string("Yes", "No") + default: + return b.string("yes", "no") + } + case 't': + switch { + case w.Flag('#'): + return b.string("TRUE", "FALSE") + case w.Flag('+'): + return b.string("True", "False") + default: + return b.string("true", "false") + } + default: + return printError(v, b, bool(b)) + } +} + +func (b Boolean) string(t, f string) string { + if b { + return t + } + return f +} + +func (b Boolean) MarshalJSON() ([]byte, error) { + return []byte(b.string("true", "false")), nil +} + +func (b *Boolean) UnmarshalJSON(j []byte) error { + return json.Unmarshal(j, (*bool)(b)) +} + +func (b Boolean) MarshalYAML() (interface{}, error) { + return bool(b), nil +} + +func (b *Boolean) UnmarshalYAML(y *yaml.Node) error { + return y.Decode((*bool)(b)) +} + +func (b Boolean) MarshalText() ([]byte, error) { + return []byte(b.String()), nil +} + +func (b *Boolean) UnmarshalText(t []byte) error { + x, err := ParseBoolean(string(t)) + if err != nil { + return err + } + *b = x + return nil +} + +var ( + _ fmt.Formatter = Boolean(false) + _ fmt.GoStringer = Boolean(false) + _ fmt.Stringer = Boolean(false) + + _ json.Marshaler = Boolean(false) + _ json.Unmarshaler = (*Boolean)(nil) + + _ yaml.Marshaler = Boolean(false) + _ yaml.Unmarshaler = (*Boolean)(nil) + + _ encoding.TextMarshaler = Boolean(false) + _ encoding.TextUnmarshaler = (*Boolean)(nil) +) diff --git a/internal/print/human/boolean_test.go b/internal/print/human/boolean_test.go new file mode 100644 index 00000000..0fa06629 --- /dev/null +++ b/internal/print/human/boolean_test.go @@ -0,0 +1,71 @@ +package human + +import ( + "fmt" + "testing" +) + +func TestBooleanParse(t *testing.T) { + for _, test := range []struct { + in string + out Boolean + }{ + {in: "true", out: true}, + {in: "True", out: true}, + {in: "TRUE", out: true}, + + {in: "false", out: false}, + {in: "False", out: false}, + {in: "FALSE", out: false}, + + {in: "yes", out: true}, + {in: "Yes", out: true}, + {in: "YES", out: true}, + + {in: "no", out: false}, + {in: "No", out: false}, + {in: "NO", out: false}, + } { + t.Run(test.in, func(t *testing.T) { + b, err := ParseBoolean(test.in) + if err != nil { + t.Fatal(err) + } + if b != test.out { + t.Error("parsed boolean mismatch:", b, "!=", test.out) + } + }) + } +} + +func TestBooleanFormat(t *testing.T) { + for _, test := range []struct { + in Boolean + fmt string + out string + }{ + {in: true, fmt: "%s", out: "yes"}, + {in: true, fmt: "%t", out: "true"}, + + {in: true, fmt: "%+s", out: "Yes"}, + {in: true, fmt: "%+t", out: "True"}, + + {in: true, fmt: "%#s", out: "YES"}, + {in: true, fmt: "%#t", out: "TRUE"}, + + {in: false, fmt: "%s", out: "no"}, + {in: false, fmt: "%t", out: "false"}, + + {in: false, fmt: "%+s", out: "No"}, + {in: false, fmt: "%+t", out: "False"}, + + {in: false, fmt: "%#s", out: "NO"}, + {in: false, fmt: "%#t", out: "FALSE"}, + } { + t.Run(test.out, func(t *testing.T) { + if s := fmt.Sprintf(test.fmt, test.in); s != test.out { + t.Error("formatted boolean mismatch:", s, "!=", test.out) + } + }) + } +} diff --git a/internal/print/human/bytes.go b/internal/print/human/bytes.go new file mode 100644 index 00000000..b380a036 --- /dev/null +++ b/internal/print/human/bytes.go @@ -0,0 +1,230 @@ +package human + +import ( + "encoding" + "encoding/json" + "fmt" + "io" + "math" + "strconv" + + yaml "gopkg.in/yaml.v3" +) + +// Bytes represents a number of bytes. +// +// The type support parsing values in formats like: +// +// 42 KB +// 8Gi +// 1.5KiB +// ... +// +// Two models are supported, using factors of 1000 and factors of 1024 via units +// like KB, MB, GB for the former, or Ki, Mi, MiB for the latter. +// +// In the current implementation, formatting is always done in factors of 1024, +// using units like Ki, Mi, Gi etc... +// +// Values may be decimals when using units larger than B. Partial bytes cannot +// be represnted (e.g. 0.5B is not supported). +type Bytes uint64 + +const ( + B Bytes = 1 + + KB Bytes = 1000 * B + MB Bytes = 1000 * KB + GB Bytes = 1000 * MB + TB Bytes = 1000 * GB + PB Bytes = 1000 * TB + + KiB Bytes = 1024 * B + MiB Bytes = 1024 * KiB + GiB Bytes = 1024 * MiB + TiB Bytes = 1024 * GiB + PiB Bytes = 1024 * TiB +) + +func ParseBytes(s string) (Bytes, error) { + f, err := ParseBytesFloat64(s) + if err != nil { + return 0, err + } + if f < 0 { + return 0, fmt.Errorf("invalid negative byte count: %q", s) + } + return Bytes(math.Floor(f)), err +} + +func ParseBytesFloat64(s string) (float64, error) { + value, unit := parseUnit(s) + + scale := Bytes(0) + switch { + case match(unit, "B"), unit == "": + scale = B + case match(unit, "KB"): + scale = KB + case match(unit, "MB"): + scale = MB + case match(unit, "GB"): + scale = GB + case match(unit, "TB"): + scale = TB + case match(unit, "PB"): + scale = PB + case match(unit, "KiB"): + scale = KiB + case match(unit, "MiB"): + scale = MiB + case match(unit, "GiB"): + scale = GiB + case match(unit, "TiB"): + scale = TiB + case match(unit, "PiB"): + scale = PiB + default: + return 0, fmt.Errorf("malformed bytes representation: %q", s) + } + + f, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0, fmt.Errorf("malformed bytes representations: %q: %w", s, err) + } + return f * float64(scale), nil +} + +type byteUnit struct { + scale Bytes + unit string +} + +var bytes1000 = [...]byteUnit{ + {B, "B"}, + {KB, "KB"}, + {MB, "MB"}, + {GB, "GB"}, + {TB, "TB"}, + {PB, "PB"}, +} + +var bytes1024 = [...]byteUnit{ + {B, ""}, + {KiB, "KiB"}, + {MiB, "MiB"}, + {GiB, "GiB"}, + {TiB, "TiB"}, + {PiB, "PiB"}, +} + +func (b Bytes) String() string { + return b.formatWith(bytes1024[:]) +} + +func (b Bytes) GoString() string { + return fmt.Sprintf("human.Bytes(%d)", uint64(b)) +} + +// Format satisfies the fmt.Formatter interface. +// +// The method supports the following formatting verbs: +// +// d base 10, unit-less +// b base 10, with unit using 1000 factors +// s base 10, with unit using 1024 factors (same as calling String) +// v same as the 's' format, unless '#' is set to print the go value +func (b Bytes) Format(w fmt.State, v rune) { + _, _ = io.WriteString(w, b.format(w, v)) +} + +func (b Bytes) format(w fmt.State, v rune) string { + switch v { + case 'd': + return strconv.FormatUint(uint64(b), 10) + case 'b': + return b.formatWith(bytes1000[:]) + case 's': + return b.formatWith(bytes1024[:]) + case 'v': + if w.Flag('#') { + return b.GoString() + } + return b.format(w, 's') + default: + return printError(v, b, uint64(b)) + } +} + +func (b Bytes) formatWith(units []byteUnit) string { + var scale Bytes + var unit string + + for i := len(units) - 1; i >= 0; i-- { + u := units[i] + + if b >= u.scale { + scale, unit = u.scale, u.unit + break + } + } + + s := ftoa(float64(b), float64(scale)) + if unit != "" { + s += " " + unit + } + return s +} + +func (b Bytes) MarshalJSON() ([]byte, error) { + return json.Marshal(uint64(b)) +} + +func (b *Bytes) UnmarshalJSON(j []byte) error { + return json.Unmarshal(j, (*uint64)(b)) +} + +func (b Bytes) MarshalYAML() (interface{}, error) { + return b.String(), nil +} + +func (b *Bytes) UnmarshalYAML(y *yaml.Node) error { + var s string + if err := y.Decode(&s); err != nil { + return err + } + p, err := ParseBytes(s) + if err != nil { + return err + } + *b = p + return nil +} + +func (b Bytes) MarshalText() ([]byte, error) { + return []byte(b.String()), nil +} + +func (b *Bytes) UnmarshalText(t []byte) error { + p, err := ParseBytes(string(t)) + if err != nil { + return err + } + *b = p + return nil +} + +var ( + _ fmt.Formatter = Bytes(0) + _ fmt.GoStringer = Bytes(0) + _ fmt.Stringer = Bytes(0) + + _ json.Marshaler = Bytes(0) + _ json.Unmarshaler = (*Bytes)(nil) + + _ yaml.Marshaler = Bytes(0) + _ yaml.Unmarshaler = (*Bytes)(nil) + + _ encoding.TextMarshaler = Bytes(0) + _ encoding.TextUnmarshaler = (*Bytes)(nil) +) diff --git a/internal/print/human/bytes_test.go b/internal/print/human/bytes_test.go new file mode 100644 index 00000000..87011c13 --- /dev/null +++ b/internal/print/human/bytes_test.go @@ -0,0 +1,110 @@ +package human + +import ( + "encoding/json" + "fmt" + "testing" + + yaml "gopkg.in/yaml.v3" +) + +func TestBytesParse(t *testing.T) { + for _, test := range []struct { + in string + out Bytes + }{ + {in: "0", out: 0}, + + {in: "2B", out: 2}, + {in: "2K", out: 2 * KB}, + {in: "2M", out: 2 * MB}, + {in: "2G", out: 2 * GB}, + {in: "2T", out: 2 * TB}, + {in: "2P", out: 2 * PB}, + + {in: "2", out: 2}, + {in: "2 KiB", out: 2 * KiB}, + {in: "2 MiB", out: 2 * MiB}, + {in: "2 GiB", out: 2 * GiB}, + {in: "2 TiB", out: 2 * TiB}, + {in: "2 PiB", out: 2 * PiB}, + + {in: "1.234 K", out: 1234}, + {in: "1.234 M", out: 1234 * KB}, + + {in: "1.5 Ki", out: 1*KiB + 512}, + {in: "1.5 Mi", out: 1*MiB + 512*KiB}, + } { + t.Run(test.in, func(t *testing.T) { + b, err := ParseBytes(test.in) + if err != nil { + t.Fatal(err) + } + if b != test.out { + t.Error("parsed bytes mismatch:", b, "!=", test.out) + } + }) + } +} + +func TestBytesFormat(t *testing.T) { + for _, test := range []struct { + in Bytes + fmt string + out string + }{ + {fmt: "%v", out: "0", in: 0}, + {fmt: "%v", out: "2", in: 2}, + + {fmt: "%v", out: "1.95 KiB", in: 2 * KB}, + {fmt: "%v", out: "1.91 MiB", in: 2 * MB}, + {fmt: "%v", out: "1.86 GiB", in: 2 * GB}, + {fmt: "%v", out: "1.82 TiB", in: 2 * TB}, + {fmt: "%v", out: "1.78 PiB", in: 2 * PB}, + + {fmt: "%v", out: "2 KiB", in: 2 * KiB}, + {fmt: "%v", out: "2 MiB", in: 2 * MiB}, + {fmt: "%v", out: "2 GiB", in: 2 * GiB}, + {fmt: "%v", out: "2 TiB", in: 2 * TiB}, + {fmt: "%v", out: "2 PiB", in: 2 * PiB}, + + {fmt: "%v", out: "1.21 KiB", in: 1234}, + {fmt: "%v", out: "1.18 MiB", in: 1234 * KB}, + + {fmt: "%v", out: "1.5 KiB", in: 1*KiB + 512}, + {fmt: "%v", out: "1.5 MiB", in: 1*MiB + 512*KiB}, + + {fmt: "%d", out: "123456789", in: 123456789}, + {fmt: "%b", out: "123 MB", in: 123456789}, + {fmt: "%s", out: "118 MiB", in: 123456789}, + {fmt: "%#v", out: "human.Bytes(123456789)", in: 123456789}, + } { + t.Run(test.out, func(t *testing.T) { + if s := fmt.Sprintf(test.fmt, test.in); s != test.out { + t.Error("formatted bytes mismatch:", s, "!=", test.out) + } + }) + } +} + +func TestBytesJSON(t *testing.T) { + testBytesEncoding(t, 1*KiB, json.Marshal, json.Unmarshal) +} + +func TestBytesYAML(t *testing.T) { + testBytesEncoding(t, 1*KiB, yaml.Marshal, yaml.Unmarshal) +} + +func testBytesEncoding(t *testing.T, x Bytes, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { + b, err := marshal(x) + if err != nil { + t.Fatal("marshal error:", err) + } + + v := Bytes(0) + if err := unmarshal(b, &v); err != nil { + t.Error("unmarshal error:", err) + } else if v != x { + t.Error("value mismatch:", v, "!=", x) + } +} diff --git a/internal/print/human/count.go b/internal/print/human/count.go new file mode 100644 index 00000000..2cd0182f --- /dev/null +++ b/internal/print/human/count.go @@ -0,0 +1,170 @@ +package human + +import ( + "encoding" + "encoding/json" + "fmt" + "io" + "math" + "strconv" + + yaml "gopkg.in/yaml.v3" +) + +// Count represents a count without a unit. +// +// The type supports parsing and formatting values like: +// +// 1234 +// 10 K +// 1.5M +// ... +type Count float64 + +const ( + K Count = 1000 + M Count = 1000 * K + G Count = 1000 * M + T Count = 1000 * G + P Count = 1000 * T +) + +func ParseCount(s string) (Count, error) { + value, unit := parseUnit(s) + + scale := Count(0) + switch { + case unit == "": + scale = 1 + case match(unit, "K"): + scale = K + case match(unit, "M"): + scale = M + case match(unit, "G"): + scale = G + case match(unit, "T"): + scale = T + case match(unit, "P"): + scale = P + default: + return 0, fmt.Errorf("malformed count representation: %q", s) + } + + f, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0, fmt.Errorf("malformed count representation: %q: %w", s, err) + } + return Count(f) * scale, nil +} + +func (c Count) String() string { + var scale Count + var unit string + var f = float64(c) + + switch c = Count(fabs(f)); { + case c >= P: + scale, unit = P, "P" + case c >= T: + scale, unit = T, "T" + case c >= G: + scale, unit = G, "G" + case c >= M: + scale, unit = M, "M" + case c >= 10*K: + scale, unit = K, "K" + default: + scale, unit = 1, "" + } + + return ftoa(f, float64(scale)) + unit +} + +func (c Count) GoString() string { + return fmt.Sprintf("human.Count(%v)", float64(c)) +} + +// Format satisfies the fmt.Formatter interface. +// +// The method supports the following formatting verbs: +// +// d base 10, unit-less, rounded to the nearest integer +// e base 10, unit-less, scientific notation +// f base 10, unit-less, decimal notation +// g base 10, unit-less, act like 'e' or 'f' depending on scale +// s base 10, with unit (same as calling String) +// v same as the 's' format, unless '#' is set to print the go value +func (c Count) Format(w fmt.State, v rune) { + _, _ = io.WriteString(w, c.format(w, v)) +} + +func (c Count) format(w fmt.State, v rune) string { + switch v { + case 'd': + return ftoa(math.Round(float64(c)), 1) + case 'e', 'f', 'g': + return strconv.FormatFloat(float64(c), byte(v), -1, 64) + case 's': + return c.String() + case 'v': + if w.Flag('#') { + return c.GoString() + } + return c.format(w, 's') + default: + return printError(v, c, float64(c)) + } +} + +func (c Count) MarshalJSON() ([]byte, error) { + return json.Marshal(float64(c)) +} + +func (c *Count) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, (*float64)(c)) +} + +func (c Count) MarshalYAML() (interface{}, error) { + return c.String(), nil +} + +func (c *Count) UnmarshalYAML(y *yaml.Node) error { + var s string + if err := y.Decode(&s); err != nil { + return err + } + p, err := ParseCount(s) + if err != nil { + return err + } + *c = p + return nil +} + +func (c Count) MarshalText() ([]byte, error) { + return []byte(c.String()), nil +} + +func (c *Count) UnmarshalText(b []byte) error { + p, err := ParseCount(string(b)) + if err != nil { + return err + } + *c = p + return nil +} + +var ( + _ fmt.Formatter = Count(0) + _ fmt.GoStringer = Count(0) + _ fmt.Stringer = Count(0) + + _ json.Marshaler = Count(0) + _ json.Unmarshaler = (*Count)(nil) + + _ yaml.Marshaler = Count(0) + _ yaml.Unmarshaler = (*Count)(nil) + + _ encoding.TextMarshaler = Count(0) + _ encoding.TextUnmarshaler = (*Count)(nil) +) diff --git a/internal/print/human/count_test.go b/internal/print/human/count_test.go new file mode 100644 index 00000000..b7048d68 --- /dev/null +++ b/internal/print/human/count_test.go @@ -0,0 +1,74 @@ +package human + +import ( + "encoding/json" + "fmt" + "testing" + + yaml "gopkg.in/yaml.v3" +) + +func TestCountParse(t *testing.T) { + for _, test := range []struct { + in string + out Count + }{ + {in: "0", out: 0}, + {in: "1234", out: 1234}, + {in: "10.2K", out: 10200}, + } { + t.Run(test.in, func(t *testing.T) { + c, err := ParseCount(test.in) + if err != nil { + t.Fatal(err) + } + if c != test.out { + t.Error("parsed count mismatch:", c, "!=", test.out) + } + }) + } +} + +func TestCountFormat(t *testing.T) { + for _, test := range []struct { + in Count + fmt string + out string + }{ + {in: 0, fmt: "%v", out: "0"}, + {in: 1234, fmt: "%v", out: "1234"}, + {in: 10234, fmt: "%v", out: "10.2K"}, + {in: 123456789, fmt: "%d", out: "123456789"}, + {in: 1234.56789, fmt: "%f", out: "1234.56789"}, + {in: 123456789, fmt: "%s", out: "123M"}, + {in: 123456789, fmt: "%#v", out: "human.Count(1.23456789e+08)"}, + } { + t.Run(test.out, func(t *testing.T) { + if s := fmt.Sprintf(test.fmt, test.in); s != test.out { + t.Error("formatted count mismatch:", s, "!=", test.out) + } + }) + } +} + +func TestCountJSON(t *testing.T) { + testCountEncoding(t, Count(1.234), json.Marshal, json.Unmarshal) +} + +func TestCountYAML(t *testing.T) { + testCountEncoding(t, Count(1.234), yaml.Marshal, yaml.Unmarshal) +} + +func testCountEncoding(t *testing.T, x Count, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { + b, err := marshal(x) + if err != nil { + t.Fatal("marshal error:", err) + } + + v := Count(0) + if err := unmarshal(b, &v); err != nil { + t.Error("unmarshal error:", err) + } else if v != x { + t.Error("value mismatch:", v, "!=", x) + } +} diff --git a/internal/print/human/duration.go b/internal/print/human/duration.go new file mode 100644 index 00000000..2815e3c3 --- /dev/null +++ b/internal/print/human/duration.go @@ -0,0 +1,399 @@ +package human + +import ( + "encoding" + "encoding/json" + "fmt" + "io" + "math" + "strconv" + "strings" + "time" + + yaml "gopkg.in/yaml.v3" +) + +const ( + Nanosecond Duration = 1 + Microsecond Duration = 1000 * Nanosecond + Millisecond Duration = 1000 * Microsecond + Second Duration = 1000 * Millisecond + Minute Duration = 60 * Second + Hour Duration = 60 * Minute + Day Duration = 24 * Hour + Week Duration = 7 * Day +) + +// Duration is based on time.Duration, but supports parsing and formatting +// more human-friendly representations. +// +// Here are examples of supported values: +// +// 5m30s +// 1d +// 4 weeks +// 1.5y +// ... +// +// The current implementation does not support decimal values, however, +// contributions are welcome to add this feature. +// +// Time being what it is, months and years are hard to represent because their +// durations vary in unpredictable ways. This is why the package only exposes +// constants up to a 1 week duration. For the sake of accuracy, years and months +// are always represented relative to a given date. Technically, leap seconds +// can cause any unit above the second to be variable, but in order to remain +// mentaly sane, we chose to ignore this detail in the implementation of this +// package. +type Duration time.Duration + +func ParseDuration(s string) (Duration, error) { + return ParseDurationUntil(s, time.Now()) +} + +func ParseDurationUntil(s string, now time.Time) (Duration, error) { + var d Duration + var input = s + + if s == "0" { + return 0, nil + } + + for len(s) != 0 { + // parse the next number + + n, r, err := parseFloat(s) + if err != nil { + return 0, fmt.Errorf("malformed duration: %s: %w", input, err) + } + s = r + + // parse "weeks", "days", "h", etc. + if s == "" { + return 0, fmt.Errorf("please include a unit ('weeks', 'h', 'm') in addition to the value (%f)", n) + } + v, r, err := parseDuration(s, n, now) + if err != nil { + return 0, fmt.Errorf("malformed duration: %s: %w", input, err) + } + s = r + + d += v + } + + return d, nil +} + +func parseDuration(s string, n float64, now time.Time) (Duration, string, error) { + s, r := parseNextToken(s) + switch { + case match(s, "weeks"): + return Duration(n * float64(Week)), r, nil + case match(s, "days"): + return Duration(n * float64(Day)), r, nil + case match(s, "hours"): + return Duration(n * float64(Hour)), r, nil + case match(s, "minutes"): + return Duration(n * float64(Minute)), r, nil + case match(s, "seconds"): + return Duration(n * float64(Second)), r, nil + case match(s, "milliseconds"), s == "ms": + return Duration(n * float64(Millisecond)), r, nil + case match(s, "microseconds"), s == "us", s == "µs": + return Duration(n * float64(Microsecond)), r, nil + case match(s, "nanoseconds"), s == "ns": + return Duration(n * float64(Nanosecond)), r, nil + case match(s, "months"): + month, day := math.Modf(n) + month, day = -month, -math.Round(28*day) // 1 month is approximately 4 weeks + return Duration(now.Sub(now.AddDate(0, int(month), int(day)))), r, nil + case match(s, "years"): + year, month := math.Modf(n) + year, month = -year, -math.Round(12*month) + return Duration(now.Sub(now.AddDate(int(year), int(month), 0))), r, nil + default: + return 0, "", fmt.Errorf("unkonwn time unit %q", s) + } +} + +type durationUnits struct { + nanosecond string + microsecond string + millisecond string + second string + minute string + hour string + day string + week string + month string + year string + separator string +} + +func (durationUnits) fix(n int, s string) string { + if n == 1 && len(s) > 3 { + return s[:len(s)-1] // trim tralinig 's' on long units + } + return s +} + +var durationsShort = durationUnits{ + nanosecond: "ns", + microsecond: "µs", + millisecond: "ms", + second: "s", + minute: "m", + hour: "h", + day: "d", + week: "w", + month: "mo", + year: "y", + separator: "", +} + +var durationsLong = durationUnits{ + nanosecond: "nanoseconds", + microsecond: "microseconds", + millisecond: "milliseconds", + second: "seconds", + minute: "minutes", + hour: "hours", + day: "days", + week: "weeks", + month: "months", + year: "years", + separator: " ", +} + +func (d Duration) String() string { + return d.text(time.Now(), 1, durationsShort) +} + +func (d Duration) GoString() string { + return fmt.Sprintf("human.Duration(%d)", int64(d)) +} + +// Format satisfies the fmt.Formatter interface. +// +// The method supports the following formatting verbs: +// +// s outputs a string representation of the duration (same as calling String) +// v same as the 's' format, unless '#' is set to print the go value +// +// The 's' and 'v' formatting verbs also interpret the options: +// +// '-' outputs full names of the time units instead of abbreviations +// '.' followed by a digit to limit the precision of the output +func (d Duration) Format(w fmt.State, v rune) { + d.formatUntil(w, v, time.Now()) +} + +func (d Duration) formatUntil(w fmt.State, v rune, now time.Time) { + _, _ = io.WriteString(w, d.format(w, v, now)) +} + +func (d Duration) format(w fmt.State, v rune, now time.Time) string { + switch v { + case 's': + var limit int + var units durationUnits + + limit, hasLimit := w.Precision() + if !hasLimit { + limit = 1 + } + if w.Flag('+') { + units = durationsLong + } else { + units = durationsShort + } + + return d.text(now, limit, units) + case 'v': + if w.Flag('#') { + return d.GoString() + } + return d.format(w, 's', now) + default: + return printError(v, d, uint64(d)) + } +} + +func (d Duration) Text(now time.Time) string { + return d.text(now, 1, durationsLong) +} + +func (d Duration) text(now time.Time, limit int, units durationUnits) string { + if d == 0 { + return "0" + units.separator + units.second + } + + if d == Duration(math.MaxInt64) || d == Duration(math.MinInt64) { + return "a while" // special values for unknown durations + } + + if d < 0 { + return "-" + (-d).text(now, limit, units) + } + + var n int + var s strings.Builder + + for i := 0; d != 0; i++ { + var unit string + + if i != 0 { + s.WriteString(units.separator) + } + + if d < 31*Day { + var scale Duration + + switch { + case d < Microsecond: + scale, unit = Nanosecond, units.nanosecond + case d < Millisecond: + scale, unit = Microsecond, units.microsecond + case d < Second: + scale, unit = Millisecond, units.millisecond + case d < Minute: + scale, unit = Second, units.second + case d < Hour: + scale, unit = Minute, units.minute + case d < Day: + scale, unit = Hour, units.hour + case d < Week: + scale, unit = Day, units.day + default: + scale, unit = Week, units.week + } + + n = int(d / scale) + d -= Duration(n) * scale + + } else if n = d.Years(now); n != 0 { + d -= Duration(now.Sub(now.AddDate(-n, 0, 0))) + unit = units.year + + } else { + n = d.Months(now) + d -= Duration(now.Sub(now.AddDate(0, -n, 0))) + unit = units.month + } + + s.WriteString(strconv.Itoa(n)) + s.WriteString(units.separator) + s.WriteString(units.fix(n, unit)) + + if limit--; limit == 0 { + break + } + } + + return s.String() +} + +func (d Duration) Formatter(now time.Time) fmt.Formatter { + return formatter(func(w fmt.State, v rune) { d.formatUntil(w, v, now) }) +} + +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Duration(d)) +} + +func (d *Duration) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, (*time.Duration)(d)) +} + +func (d Duration) MarshalYAML() (interface{}, error) { + return time.Duration(d).String(), nil +} + +func (d *Duration) UnmarshalYAML(y *yaml.Node) error { + var s string + if err := y.Decode(&s); err != nil { + return err + } + p, err := time.ParseDuration(s) + if err != nil { + return err + } + *d = Duration(p) + return nil +} + +func (d Duration) MarshalText() ([]byte, error) { + return []byte(d.Text(time.Now())), nil +} + +func (d *Duration) UnmarshalText(b []byte) error { + p, err := ParseDuration(string(b)) + if err != nil { + return err + } + *d = p + return nil +} + +func (d Duration) Nanoseconds() int { return int(d) } + +func (d Duration) Microseconds() int { return int(d) / int(Microsecond) } + +func (d Duration) Milliseconds() int { return int(d) / int(Millisecond) } + +func (d Duration) Seconds() int { return int(d) / int(Second) } + +func (d Duration) Minutes() int { return int(d) / int(Minute) } + +func (d Duration) Hours() int { return int(d) / int(Hour) } + +func (d Duration) Days() int { return int(d) / int(Day) } + +func (d Duration) Weeks() int { return int(d) / int(Week) } + +func (d Duration) Months(until time.Time) int { + if d < 0 { + return -((-d).Months(until.Add(-time.Duration(d)))) + } + + cursor := until.Add(-time.Duration(d + 1)) + months := 0 + + for cursor.Before(until) { + cursor = cursor.AddDate(0, 1, 0) + months++ + } + + return months - 1 +} + +func (d Duration) Years(until time.Time) int { + if d < 0 { + return -((-d).Years(until.Add(-time.Duration(d)))) + } + + cursor := until.Add(-time.Duration(d + 1)) + years := 0 + + for cursor.Before(until) { + cursor = cursor.AddDate(1, 0, 0) + years++ + } + + return years - 1 +} + +var ( + _ fmt.Formatter = Duration(0) + _ fmt.GoStringer = Duration(0) + _ fmt.Stringer = Duration(0) + + _ json.Marshaler = Duration(0) + _ json.Unmarshaler = (*Duration)(nil) + + _ yaml.Marshaler = Duration(0) + _ yaml.Unmarshaler = (*Duration)(nil) + + _ encoding.TextMarshaler = Duration(0) + _ encoding.TextUnmarshaler = (*Duration)(nil) +) diff --git a/internal/print/human/duration_test.go b/internal/print/human/duration_test.go new file mode 100644 index 00000000..7d9c47ae --- /dev/null +++ b/internal/print/human/duration_test.go @@ -0,0 +1,125 @@ +package human + +import ( + "encoding/json" + "fmt" + "testing" + + yaml "gopkg.in/yaml.v3" +) + +func TestDurationParse(t *testing.T) { + for _, test := range []struct { + in string + out Duration + }{ + {in: "0", out: 0}, + + {in: "1ns", out: Nanosecond}, + {in: "1µs", out: Microsecond}, + {in: "1ms", out: Millisecond}, + {in: "1s", out: Second}, + {in: "1m", out: Minute}, + {in: "1h", out: Hour}, + + {in: "1d", out: 24 * Hour}, + {in: "2d", out: 48 * Hour}, + {in: "1w", out: 7 * 24 * Hour}, + {in: "2w", out: 14 * 24 * Hour}, + + {in: "1 nanosecond", out: Nanosecond}, + {in: "1 microsecond", out: Microsecond}, + {in: "1 millisecond", out: Millisecond}, + {in: "1 second", out: Second}, + {in: "1 minute", out: Minute}, + {in: "1 hour", out: Hour}, + + {in: "1 day", out: 24 * Hour}, + {in: "2 days", out: 48 * Hour}, + {in: "1 week", out: 7 * 24 * Hour}, + {in: "2 weeks", out: 14 * 24 * Hour}, + + {in: "1m30s", out: 1*Minute + 30*Second}, + {in: "1.5m", out: 1*Minute + 30*Second}, + } { + t.Run(test.in, func(t *testing.T) { + d, err := ParseDuration(test.in) + if err != nil { + t.Fatal(err) + } + if d != test.out { + t.Error("parsed duration mismatch:", d, "!=", test.out) + } + }) + } +} + +func TestDurationError(t *testing.T) { + _, err := ParseDuration("10") + if err == nil { + t.Fatal(err, "ParseDuration(10), expected error, got nil") + } + if want := "please include a unit ('weeks', 'h', 'm') in addition to the value (10.000000)"; err.Error() != want { + t.Errorf(`ParseDuration("10"), got %q, want %q`, err.Error(), want) + } +} + +func TestDurationFormat(t *testing.T) { + for _, test := range []struct { + in Duration + fmt string + out string + }{ + {fmt: "%v", out: "0s", in: 0}, + + {fmt: "%v", out: "1ns", in: Nanosecond}, + {fmt: "%v", out: "1µs", in: Microsecond}, + {fmt: "%v", out: "1ms", in: Millisecond}, + {fmt: "%v", out: "1s", in: Second}, + {fmt: "%v", out: "1m", in: Minute}, + {fmt: "%v", out: "1h", in: Hour}, + + {fmt: "%v", out: "1d", in: 24 * Hour}, + {fmt: "%v", out: "2d", in: 48 * Hour}, + {fmt: "%v", out: "1w", in: 7 * 24 * Hour}, + {fmt: "%v", out: "2w", in: 14 * 24 * Hour}, + {fmt: "%v", out: "1mo", in: 33 * 24 * Hour}, + {fmt: "%v", out: "2mo", in: 66 * 24 * Hour}, + {fmt: "%v", out: "1y", in: 400 * 24 * Hour}, + {fmt: "%v", out: "2y", in: 800 * 24 * Hour}, + + {fmt: "%v", out: "1m", in: 1*Minute + 30*Second}, + {fmt: "%+.1v", out: "2 hours", in: 2*Hour + 1*Minute + 30*Second}, + {fmt: "%+.2v", out: "2 hours 1 minute", in: 2*Hour + 1*Minute + 30*Second}, + {fmt: "%+.3v", out: "2 hours 1 minute 30 seconds", in: 2*Hour + 1*Minute + 30*Second}, + {fmt: "%#v", out: "human.Duration(60000000000)", in: 1 * Minute}, + } { + t.Run(test.out, func(t *testing.T) { + if s := fmt.Sprintf(test.fmt, test.in); s != test.out { + t.Error("duration string mismatch:", s, "!=", test.out) + } + }) + } +} + +func TestDurationJSON(t *testing.T) { + testDurationEncoding(t, (2 * Hour), json.Marshal, json.Unmarshal) +} + +func TestDurationYAML(t *testing.T) { + testDurationEncoding(t, (2 * Hour), yaml.Marshal, yaml.Unmarshal) +} + +func testDurationEncoding(t *testing.T, x Duration, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { + b, err := marshal(x) + if err != nil { + t.Fatal("marshal error:", err) + } + + v := Duration(0) + if err := unmarshal(b, &v); err != nil { + t.Error("unmarshal error:", err) + } else if v != x { + t.Error("value mismatch:", v, "!=", x) + } +} diff --git a/internal/print/human/human.go b/internal/print/human/human.go new file mode 100644 index 00000000..7d9df15d --- /dev/null +++ b/internal/print/human/human.go @@ -0,0 +1,206 @@ +// Package human provides types that support parsing and formatting +// human-friendly representations of values in various units. +// +// The package only exposes type names that are not that common to find in Go +// programs (in our experience). For that reason, it can be interesting to +// import the package as '.' (dot) to inject the symbols in the namespace of the +// importer, especially in the common case where it's being used in the main +// package of a program, for example: +// +// import ( +// . "github.com/segmentio/cli/human" +// ) +// +// This can help improve code readability by importing constants in the package +// namespace, allowing constructs like: +// +// type clientConfig{ +// DialTimeout Duration +// BufferSize Bytes +// RateLimit Rate +// } +// ... +// config := clientConfig{ +// DialTimeout: 10 * Second, +// BufferSize: 64 * KiB, +// RateLimit: 20 * PerSecond, +// } +package human + +import ( + "fmt" + "strconv" + "strings" + "unicode" +) + +func isDot(r rune) bool { + return r == '.' +} + +func isExp(r rune) bool { + return r == 'e' || r == 'E' +} + +func isSign(r rune) bool { + return r == '-' || r == '+' +} + +func isNumberPrefix(r rune) bool { + return isSign(r) || unicode.IsDigit(r) +} + +func hasPrefixFunc(s string, f func(rune) bool) bool { + for _, r := range s { + return f(r) + } + return false +} + +func countPrefixFunc(s string, f func(rune) bool) int { + var i int + var r rune + + terminated := false + for i, r = range s { + if !f(r) { + terminated = true + break + } + } + if !terminated { + return i + 1 + } + + return i +} + +func skipSpaces(s string) string { + return strings.TrimLeftFunc(s, unicode.IsSpace) +} + +func trimSpaces(s string) string { + return strings.TrimRightFunc(s, unicode.IsSpace) +} + +func parseNextNumber(s string) (string, string) { + i := 0 + + // integer part + i += countPrefixFunc(s[i:], isSign) // - or + + i += countPrefixFunc(s[i:], unicode.IsDigit) + + // Count all of the digits after the decimal (if one exists) + if hasPrefixFunc(s[i:], isDot) { + i++ // . + i += countPrefixFunc(s[i:], unicode.IsDigit) + } + + // exponent part + if hasPrefixFunc(s[i:], isExp) { + i++ // e or E + i += countPrefixFunc(s[i:], isSign) // - or + + i += countPrefixFunc(s[i:], unicode.IsDigit) + } + + return s[:i], skipSpaces(s[i:]) +} + +func parseNextToken(s string) (string, string) { + if hasPrefixFunc(s, isNumberPrefix) { + return parseNextNumber(s) + } + + for i, r := range s { + if isNumberPrefix(r) || unicode.IsSpace(r) { + return s[:i], skipSpaces(s[i:]) + } + } + + return s, "" +} + +// parseFloat tries to parse a number at the beginning of s, and returns the +// remainder as well as any error that occurs. +func parseFloat(s string) (float64, string, error) { + s, r := parseNextNumber(s) + f, err := strconv.ParseFloat(s, 64) + return f, r, err +} + +func parseUnit(s string) (head, unit string) { + i := strings.LastIndexFunc(s, func(r rune) bool { + return !unicode.IsLetter(r) + }) + + if i < 0 { + head = s + return + } + + head = trimSpaces(s[:i+1]) + unit = s[i+1:] + return +} + +func match(s, pattern string) bool { + return len(s) <= len(pattern) && strings.EqualFold(s, pattern[:len(s)]) +} + +type suffix byte + +func (c suffix) trim(s string) string { + for len(s) > 0 && s[len(s)-1] == byte(c) { + s = s[:len(s)-1] + } + return s +} + +func (c suffix) match(s string) bool { + return len(s) > 0 && s[len(s)-1] == byte(c) +} + +func fabs(value float64) float64 { + if value < 0 { + return -value + } + return value +} + +func ftoa(value, scale float64) string { + var format string + + if value == 0 { + return "0" + } + + if value < 0 { + return "-" + ftoa(-value, scale) + } + + switch { + case (value / scale) >= 100: + format = "%.0f" + case (value / scale) >= 10: + format = "%.1f" + case scale > 1: + format = "%.2f" + default: + format = "%.3f" + } + + s := fmt.Sprintf(format, value/scale) + if strings.Contains(s, ".") { + s = suffix('0').trim(s) + s = suffix('.').trim(s) + } + return s +} + +func printError(verb rune, typ, val interface{}) string { + return fmt.Sprintf("%%!%c(%T=%v)", verb, typ, val) +} + +type formatter func(fmt.State, rune) + +func (f formatter) Format(w fmt.State, v rune) { f(w, v) } diff --git a/internal/print/human/human_test.go b/internal/print/human/human_test.go new file mode 100644 index 00000000..226883ad --- /dev/null +++ b/internal/print/human/human_test.go @@ -0,0 +1,51 @@ +package human + +import ( + "testing" +) + +func TestParseNextToken(t *testing.T) { + for _, test := range []struct { + in string + head string + tail string + }{ + {in: "", head: "", tail: ""}, + {in: "a", head: "a", tail: ""}, + {in: "a b c", head: "a", tail: "b c"}, + {in: "abc123", head: "abc", tail: "123"}, + {in: "abc-123", head: "abc", tail: "-123"}, + {in: "abc+123", head: "abc", tail: "+123"}, + {in: "abc 123", head: "abc", tail: "123"}, + {in: "123abc", head: "123", tail: "abc"}, + {in: "+123abc", head: "+123", tail: "abc"}, + {in: "-123abc", head: "-123", tail: "abc"}, + {in: "123 abc", head: "123", tail: "abc"}, + {in: "123.abc", head: "123.", tail: "abc"}, + {in: "123.456abc", head: "123.456", tail: "abc"}, + {in: "123e4abc", head: "123e4", tail: "abc"}, + {in: "123E4abc", head: "123E4", tail: "abc"}, + {in: "-123.4e+56abc", head: "-123.4e+56", tail: "abc"}, + } { + t.Run("", func(t *testing.T) { + head, tail := parseNextToken(test.in) + if head != test.head { + t.Errorf("head mismatch: %q != %q", head, test.head) + } + if tail != test.tail { + t.Errorf("tail mismatch: %q != %q", tail, test.tail) + } + }) + } +} + +func TestParseFloat(t *testing.T) { + in := "10" + n, _, err := parseFloat(in) + if err != nil { + t.Fatalf("parseFloat(%q): got %q, want nil", in, err) + } + if want := float64(10); n != want { + t.Fatalf("parseFloat(%q): got %f, want %f", in, n, want) + } +} diff --git a/internal/print/human/number.go b/internal/print/human/number.go new file mode 100644 index 00000000..573d21a9 --- /dev/null +++ b/internal/print/human/number.go @@ -0,0 +1,146 @@ +package human + +import ( + "encoding" + "encoding/json" + "fmt" + "io" + "math" + "strconv" + "strings" + + yaml "gopkg.in/yaml.v3" +) + +// Number is similar to Count, but supports values with separators for +// readability purposes. +// +// The type supports parsing and formatting values likes: +// +// 123 +// 1.5 +// 2,000,000 +// ... +type Number float64 + +func ParseNumber(s string) (Number, error) { + r := strings.ReplaceAll(s, ",", "") + f, err := strconv.ParseFloat(r, 64) + if err != nil { + return 0, fmt.Errorf("malformed number: %s: %w", s, err) + } + return Number(f), nil +} + +func (n Number) String() string { + if n == 0 { + return "0" + } + + if n < 0 { + return "-" + (-n).String() + } + + if n <= 1e-3 || n >= 1e12 { + return strconv.FormatFloat(float64(n), 'g', -1, 64) + } + + i, d := math.Modf(float64(n)) + parts := make([]string, 0, 4) + + for u := uint64(i); u > 0; u /= 1000 { + parts = append(parts, strconv.FormatUint(u%1000, 10)) + } + + for i, j := 0, len(parts)-1; i < j; { + parts[i], parts[j] = parts[j], parts[i] + i++ + j-- + } + + r := strings.Join(parts, ",") + + if d != 0 { + r += "." + r += suffix('0').trim(strconv.FormatUint(uint64(math.Round(d*1000)), 10)) + } + + return r +} + +func (n Number) GoString() string { + return fmt.Sprintf("human.Number(%v)", float64(n)) +} + +// Format satisfies the fmt.Formatter interface. +// +// The method supports the following formatting verbs: +// +// e base 10, separator-free, scientific notation +// f base 10, separator-free, decimal notation +// g base 10, separator-free, act like 'e' or 'f' depending on scale +// s base 10, with separators (same as calling String) +// v same as the 's' format, unless '#' is set to print the go value +func (n Number) Format(w fmt.State, v rune) { + _, _ = io.WriteString(w, n.format(w, v)) +} + +func (n Number) format(w fmt.State, v rune) string { + switch v { + case 'e', 'f', 'g': + return strconv.FormatFloat(float64(n), byte(v), -1, 64) + case 's': + return n.String() + case 'v': + if w.Flag('#') { + return n.GoString() + } + return n.format(w, 's') + default: + return printError(v, n, float64(n)) + } +} + +func (n Number) MarshalJSON() ([]byte, error) { + return json.Marshal(float64(n)) +} + +func (n *Number) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, (*float64)(n)) +} + +func (n Number) MarshalYAML() (interface{}, error) { + return float64(n), nil +} + +func (n *Number) UnmarshalYAML(y *yaml.Node) error { + return y.Decode((*float64)(n)) +} + +func (n Number) MarshalText() ([]byte, error) { + return []byte(n.String()), nil +} + +func (n *Number) UnmarshalText(b []byte) error { + p, err := ParseNumber(string(b)) + if err != nil { + return err + } + *n = p + return nil +} + +var ( + _ fmt.Formatter = Number(0) + _ fmt.GoStringer = Number(0) + _ fmt.Stringer = Number(0) + + _ json.Marshaler = Number(0) + _ json.Unmarshaler = (*Number)(nil) + + _ yaml.Marshaler = Number(0) + _ yaml.Unmarshaler = (*Number)(nil) + + _ encoding.TextMarshaler = Number(0) + _ encoding.TextUnmarshaler = (*Number)(nil) +) diff --git a/internal/print/human/number_test.go b/internal/print/human/number_test.go new file mode 100644 index 00000000..5e94b2cb --- /dev/null +++ b/internal/print/human/number_test.go @@ -0,0 +1,76 @@ +package human + +import ( + "encoding/json" + "fmt" + "testing" + + yaml "gopkg.in/yaml.v3" +) + +func TestNumberParse(t *testing.T) { + for _, test := range []struct { + in string + out Number + }{ + {in: "0", out: 0}, + {in: "1234", out: 1234}, + {in: "1,234", out: 1234}, + {in: "1,234.567", out: 1234.567}, + } { + t.Run(test.in, func(t *testing.T) { + n, err := ParseNumber(test.in) + if err != nil { + t.Fatal(err) + } + if n != test.out { + t.Error("parsed number mismatch:", n, "!=", test.out) + } + }) + } +} + +func TestNumberFormat(t *testing.T) { + for _, test := range []struct { + in Number + fmt string + out string + }{ + {in: 0, fmt: "%v", out: "0"}, + {in: 1234, fmt: "%v", out: "1,234"}, + {in: 1234.567, fmt: "%v", out: "1,234.567"}, + {in: 123456.789, fmt: "%v", out: "123,456.789"}, + {in: 1234567.89, fmt: "%v", out: "1,234,567.89"}, + {in: 1234567.89, fmt: "%f", out: "1234567.89"}, + {in: 1234567.89, fmt: "%s", out: "1,234,567.89"}, + {in: 1234567.89, fmt: "%#v", out: "human.Number(1.23456789e+06)"}, + } { + t.Run(test.out, func(t *testing.T) { + if s := fmt.Sprintf(test.fmt, test.in); s != test.out { + t.Error("formatted number mismatch:", s, "!=", test.out) + } + }) + } +} + +func TestNumberJSON(t *testing.T) { + testNumberEncoding(t, Number(1.234), json.Marshal, json.Unmarshal) +} + +func TestNumberYAML(t *testing.T) { + testNumberEncoding(t, Number(1.234), yaml.Marshal, yaml.Unmarshal) +} + +func testNumberEncoding(t *testing.T, x Number, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { + b, err := marshal(x) + if err != nil { + t.Fatal("marshal error:", err) + } + + v := Number(0) + if err := unmarshal(b, &v); err != nil { + t.Error("unmarshal error:", err) + } else if v != x { + t.Error("value mismatch:", v, "!=", x) + } +} diff --git a/internal/print/human/path.go b/internal/print/human/path.go new file mode 100644 index 00000000..314a4e4b --- /dev/null +++ b/internal/print/human/path.go @@ -0,0 +1,32 @@ +package human + +import ( + "bytes" + "os" + "os/user" + "path/filepath" +) + +// Path represents a path on the file system. +// +// The type interprets the special prefix "~/" as representing the home +// directory of the user that the program is running as. +type Path string + +func (p *Path) UnmarshalText(b []byte) error { + switch { + case bytes.HasPrefix(b, []byte{'~', filepath.Separator}): + home, ok := os.LookupEnv("HOME") + if !ok { + u, err := user.Current() + if err != nil { + return err + } + home = u.HomeDir + } + *p = Path(filepath.Join(home, string(b[2:]))) + default: + *p = Path(b) + } + return nil +} diff --git a/internal/print/human/path_test.go b/internal/print/human/path_test.go new file mode 100644 index 00000000..8766adfd --- /dev/null +++ b/internal/print/human/path_test.go @@ -0,0 +1,33 @@ +package human + +import ( + "os" + "path/filepath" + "testing" +) + +func TestPath(t *testing.T) { + separator := string([]byte{filepath.Separator}) + + tests := []struct { + in string + out Path + }{ + {in: ".", out: "."}, + {in: separator, out: Path(separator)}, + {in: filepath.Join(".", "hello", "world"), out: Path(filepath.Join(".", "hello", "world"))}, + {in: filepath.Join("~", "hello", "world"), out: Path(filepath.Join(os.Getenv("HOME"), "hello", "world"))}, + } + + for _, test := range tests { + t.Run(test.in, func(t *testing.T) { + path := Path("") + + if err := path.UnmarshalText([]byte(test.in)); err != nil { + t.Error(err) + } else if path != test.out { + t.Errorf("path mismatch: %q != %q", path, test.out) + } + }) + } +} diff --git a/internal/print/human/rate.go b/internal/print/human/rate.go new file mode 100644 index 00000000..81a1ec12 --- /dev/null +++ b/internal/print/human/rate.go @@ -0,0 +1,202 @@ +package human + +import ( + "encoding" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + "unicode" + + yaml "gopkg.in/yaml.v3" +) + +// Rate represents a count devided by a unit of time. +// +// The type supports parsing and formatting values like: +// +// 200/s +// 1 / minute +// 0.5/week +// ... +// +// Rate values are always stored in their per-second form in Go programs, +// and properly converted during parsing and formatting. +type Rate float64 + +const ( + PerNanosecond Rate = 1 / Rate(Nanosecond) + PerMicrosecond Rate = 1 / Rate(Microsecond) + PerMillisecond Rate = 1 / Rate(Millisecond) + PerSecond Rate = 1 / Rate(Second) + PerMinute Rate = 1 / Rate(Minute) + PerHour Rate = 1 / Rate(Hour) + PerDay Rate = 1 / Rate(Day) + PerWeek Rate = 1 / Rate(Week) +) + +func ParseRate(s string) (Rate, error) { + var text string + var unit string + var rate Rate + + if i := strings.IndexByte(s, '/'); i < 0 { + text = s + } else { + text = strings.TrimLeftFunc(s[:i], unicode.IsSpace) + unit = strings.TrimRightFunc(s[i+1:], unicode.IsSpace) + } + + c, err := ParseCount(text) + if err != nil { + return 0, fmt.Errorf("malformed rate representation: %q", s) + } + + switch { + case match(unit, "week"): + rate = PerWeek + case match(unit, "day"): + rate = PerDay + case match(unit, "hour"): + rate = PerHour + case match(unit, "minute"): + rate = PerMinute + case match(unit, "second"), unit == "": + rate = PerSecond + case match(unit, "millisecond"), unit == "ms", unit == "µs": + rate = PerMillisecond + case match(unit, "microsecond"), unit == "us": + rate = PerMicrosecond + case match(unit, "nanosecond"), unit == "ns": + rate = PerNanosecond + default: + return 0, fmt.Errorf("malformed unit representation: %q", s) + } + + return Rate(c) * (rate / PerSecond), nil +} + +func (r Rate) String() string { + return r.Text(Second) +} + +func (r Rate) GoString() string { + return fmt.Sprintf("human.Rate(%v)", float64(r)) +} + +// Format satisfies the fmt.Formatter interface. +// +// The method supports the following formatting verbs: +// +// e base 10, unit-less, scientific notation +// f base 10, unit-less, decimal notation +// g base 10, unit-less, act like 'e' or 'f' depending on scale +// s base 10, with units (same as calling String) +// v same as the 's' format, unless '#' is set to print the go value +func (r Rate) Format(w fmt.State, v rune) { + r.formatPer(w, v, Second) +} + +func (r Rate) formatPer(w fmt.State, v rune, d Duration) { + _, _ = io.WriteString(w, r.format(w, v, d)) +} + +func (r Rate) format(w fmt.State, v rune, d Duration) string { + switch v { + case 'e', 'f', 'g': + return strconv.FormatFloat(float64(r), byte(v), -1, 64) + case 's': + return r.Text(d) + case 'v': + if w.Flag('#') { + return r.GoString() + } + return r.format(w, 's', d) + default: + return printError(v, r, float64(r)) + } +} + +func (r Rate) Text(d Duration) string { + var unit string + + switch { + case d >= Week: + unit = "/w" + case d >= Day: + unit = "/d" + case d >= Hour: + unit = "/h" + case d >= Minute: + unit = "/m" + case d >= Second: + unit = "/s" + case d >= Millisecond: + unit = "/ms" + case d >= Microsecond: + unit = "/µs" + default: + unit = "/ns" + } + + r /= Rate(d) * PerSecond + return Count(r).String() + unit +} + +func (r Rate) Formatter(d Duration) fmt.Formatter { + return formatter(func(w fmt.State, v rune) { r.formatPer(w, v, d) }) +} + +func (r Rate) MarshalJSON() ([]byte, error) { + return json.Marshal(float64(r)) +} + +func (r *Rate) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, (*float64)(r)) +} + +func (r Rate) MarshalYAML() (interface{}, error) { + return r.String(), nil +} + +func (r *Rate) UnmarshalYAML(y *yaml.Node) error { + var s string + if err := y.Decode(&s); err != nil { + return err + } + p, err := ParseRate(s) + if err != nil { + return err + } + *r = p + return nil +} + +func (r Rate) MarshalText() ([]byte, error) { + return []byte(r.String()), nil +} + +func (r *Rate) UnmarshalText(b []byte) error { + p, err := ParseRate(string(b)) + if err != nil { + return err + } + *r = p + return nil +} + +var ( + _ fmt.Formatter = Rate(0) + _ fmt.GoStringer = Rate(0) + _ fmt.Stringer = Rate(0) + + _ json.Marshaler = Rate(0) + _ json.Unmarshaler = (*Rate)(nil) + + _ yaml.Marshaler = Rate(0) + _ yaml.Unmarshaler = (*Rate)(nil) + + _ encoding.TextMarshaler = Rate(0) + _ encoding.TextUnmarshaler = (*Rate)(nil) +) diff --git a/internal/print/human/rate_test.go b/internal/print/human/rate_test.go new file mode 100644 index 00000000..2c694e77 --- /dev/null +++ b/internal/print/human/rate_test.go @@ -0,0 +1,77 @@ +package human + +import ( + "encoding/json" + "fmt" + "testing" + + yaml "gopkg.in/yaml.v3" +) + +func TestRateParse(t *testing.T) { + for _, test := range []struct { + in string + out Rate + }{ + {in: "0", out: 0}, + {in: "0/s", out: 0}, + {in: "1234/s", out: 1234}, + {in: "10.2K/s", out: 10200}, + } { + t.Run(test.in, func(t *testing.T) { + r, err := ParseRate(test.in) + if err != nil { + t.Fatal(err) + } + if r != test.out { + t.Error("parsed rate mismatch:", r, "!=", test.out) + } + }) + } +} + +func TestRateFormat(t *testing.T) { + for _, test := range []struct { + in Rate + fmt string + out string + unit Duration + }{ + {in: 0, fmt: "%v", out: "0/s", unit: Second}, + {in: 1234, fmt: "%v", out: "1234/s", unit: Second}, + {in: 10234, fmt: "%v", out: "10.2K/s", unit: Second}, + {in: 0.1, fmt: "%v", out: "100/ms", unit: Millisecond}, + {in: 604800, fmt: "%v", out: "1/w", unit: Week}, + {in: 1512000, fmt: "%v", out: "2.5/w", unit: Week}, + {in: 25, fmt: "%s", out: "25/s", unit: Second}, + {in: 25, fmt: "%#v", out: "human.Rate(25)", unit: Second}, + } { + t.Run(test.out, func(t *testing.T) { + if s := fmt.Sprintf(test.fmt, test.in.Formatter(test.unit)); s != test.out { + t.Error("formatted rate mismatch:", s, "!=", test.out) + } + }) + } +} + +func TestRateJSON(t *testing.T) { + testRateEncoding(t, Rate(1.234), json.Marshal, json.Unmarshal) +} + +func TestRateYAML(t *testing.T) { + testRateEncoding(t, Rate(1.234), yaml.Marshal, yaml.Unmarshal) +} + +func testRateEncoding(t *testing.T, x Rate, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { + b, err := marshal(x) + if err != nil { + t.Fatal("marshal error:", err) + } + + v := Rate(0) + if err := unmarshal(b, &v); err != nil { + t.Error("unmarshal error:", err) + } else if v != x { + t.Error("value mismatch:", v, "!=", x) + } +} diff --git a/internal/print/human/ratio.go b/internal/print/human/ratio.go new file mode 100644 index 00000000..9c0fc467 --- /dev/null +++ b/internal/print/human/ratio.go @@ -0,0 +1,145 @@ +package human + +import ( + "encoding" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + + yaml "gopkg.in/yaml.v3" +) + +// Ratio represents percentage-like values. +// +// The type supports parsing and formatting values like: +// +// 0.1 +// 25% +// 0.5 % +// ... +// +// Ratio values are stored as floating pointer numbers between 0 and 1 (assuming +// they stay within the 0-100% bounds), and formatted as percentages. +type Ratio float64 + +func ParseRatio(s string) (Ratio, error) { + k := 1.0 + p := suffix('%') + + if p.match(s) { + k = 100.0 + s = trimSpaces(s[:len(s)-1]) + } + + f, err := strconv.ParseFloat(s, 64) + return Ratio(f / k), err +} + +func (r Ratio) String() string { + return r.Text(2) +} + +func (r Ratio) GoString() string { + return fmt.Sprintf("human.Ratio(%v)", float64(r)) +} + +// Format satisfies the fmt.Formatter interface. +// +// The method supports the following formatting verbs: +// +// e base 10, unit-less, scientific notation +// f base 10, unit-less, decimal notation +// g base 10, unit-less, act like 'e' or 'f' depending on scale +// s base 10, with units (same as calling String) +// v same as the 's' format, unless '#' is set to print the go value +func (r Ratio) Format(w fmt.State, v rune) { + r.formatWith(w, v, 2) +} + +func (r Ratio) formatWith(w fmt.State, v rune, p int) { + _, _ = io.WriteString(w, r.format(w, v, p)) +} + +func (r Ratio) format(w fmt.State, v rune, p int) string { + switch v { + case 'e', 'f', 'g': + return strconv.FormatFloat(float64(r), byte(v), -1, 64) + case 's': + return r.Text(p) + case 'v': + if w.Flag('#') { + return r.GoString() + } + return r.format(w, 's', p) + default: + return printError(v, r, float64(r)) + } +} + +func (r Ratio) Text(precision int) string { + s := strconv.FormatFloat(100*float64(r), 'f', precision, 64) + if strings.Contains(s, ".") { + s = suffix('0').trim(s) + s = suffix('.').trim(s) + } + return s + "%" +} + +func (r Ratio) Formatter(precision int) fmt.Formatter { + return formatter(func(w fmt.State, v rune) { r.formatWith(w, v, precision) }) +} + +func (r Ratio) MarshalJSON() ([]byte, error) { + return json.Marshal(float64(r)) +} + +func (r *Ratio) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, (*float64)(r)) +} + +func (r Ratio) MarshalYAML() (interface{}, error) { + return r.Text(-1), nil +} + +func (r *Ratio) UnmarshalYAML(y *yaml.Node) error { + var s string + if err := y.Decode(&s); err != nil { + return err + } + p, err := ParseRatio(s) + if err != nil { + return err + } + *r = Ratio(p) + return nil +} + +func (r Ratio) MarshalText() ([]byte, error) { + return []byte(r.String()), nil +} + +func (r *Ratio) UnmarshalText(b []byte) error { + p, err := ParseRatio(string(b)) + if err != nil { + return err + } + *r = p + return nil +} + +var ( + _ fmt.Formatter = Ratio(0) + _ fmt.GoStringer = Ratio(0) + _ fmt.Stringer = Ratio(0) + + _ json.Marshaler = Ratio(0) + _ json.Unmarshaler = (*Ratio)(nil) + + _ yaml.Marshaler = Ratio(0) + _ yaml.Unmarshaler = (*Ratio)(nil) + + _ encoding.TextMarshaler = Ratio(0) + _ encoding.TextUnmarshaler = (*Ratio)(nil) +) diff --git a/internal/print/human/ratio_test.go b/internal/print/human/ratio_test.go new file mode 100644 index 00000000..9885c5ab --- /dev/null +++ b/internal/print/human/ratio_test.go @@ -0,0 +1,75 @@ +package human + +import ( + "encoding/json" + "fmt" + "testing" + + yaml "gopkg.in/yaml.v3" +) + +func TestRatioParse(t *testing.T) { + for _, test := range []struct { + in string + out Ratio + }{ + {in: "0", out: 0}, + {in: "0%", out: 0}, + {in: "0.0%", out: 0}, + {in: "12.34%", out: 0.1234}, + {in: "100%", out: 1}, + {in: "200%", out: 2}, + } { + t.Run(test.in, func(t *testing.T) { + n, err := ParseRatio(test.in) + if err != nil { + t.Fatal(err) + } + if n != test.out { + t.Error("parsed ratio mismatch:", n, "!=", test.out) + } + }) + } +} + +func TestRatioFormat(t *testing.T) { + for _, test := range []struct { + in Ratio + fmt string + out string + }{ + {in: 0, fmt: "%v", out: "0%"}, + {in: 0.1234, fmt: "%v", out: "12.34%"}, + {in: 1, fmt: "%v", out: "100%"}, + {in: 2, fmt: "%v", out: "200%"}, + {in: 0.234, fmt: "%#v", out: "human.Ratio(0.234)"}, + } { + t.Run(test.out, func(t *testing.T) { + if s := fmt.Sprintf(test.fmt, test.in); s != test.out { + t.Error("formatted ratio mismatch:", s, "!=", test.out) + } + }) + } +} + +func TestRatioJSON(t *testing.T) { + testRatioEncoding(t, Ratio(0.234), json.Marshal, json.Unmarshal) +} + +func TestRatioYAML(t *testing.T) { + testRatioEncoding(t, Ratio(0.234), yaml.Marshal, yaml.Unmarshal) +} + +func testRatioEncoding(t *testing.T, x Ratio, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { + b, err := marshal(x) + if err != nil { + t.Fatal("marshal error:", err) + } + + v := Ratio(0) + if err := unmarshal(b, &v); err != nil { + t.Error("unmarshal error:", err) + } else if v != x { + t.Error("value mismatch:", v, "!=", x) + } +} diff --git a/internal/print/human/time.go b/internal/print/human/time.go new file mode 100644 index 00000000..14bf1180 --- /dev/null +++ b/internal/print/human/time.go @@ -0,0 +1,202 @@ +package human + +import ( + "encoding" + "encoding/json" + "fmt" + "io" + "strings" + "time" + "unicode" + + yaml "gopkg.in/yaml.v3" +) + +// Time represents absolute point in times. The implementation is based on +// time.Time. +// +// The type supports all default time formats provided by the standard time +// package, as well as parsing and formatting values relative to a given time +// point, for example: +// +// 5 minutes ago +// 1h later +// ... +type Time time.Time + +func ParseTime(s string) (Time, error) { + return ParseTimeAt(s, time.Now()) +} + +func ParseTimeAt(s string, now time.Time) (Time, error) { + if s == "now" { + return Time(now), nil + } + + if strings.HasSuffix(s, " ago") { + s = strings.TrimLeftFunc(s[:len(s)-4], unicode.IsSpace) + d, err := ParseDurationUntil(s, now) + if err != nil { + return Time{}, fmt.Errorf("malformed time representation: %q", s) + } + return Time(now.Add(-time.Duration(d))), nil + } + + if strings.HasSuffix(s, " later") { + s = strings.TrimRightFunc(s[:len(s)-6], unicode.IsSpace) + d, err := ParseDurationUntil(s, now) + if err != nil { + return Time{}, fmt.Errorf("malformed time representation: %q", s) + } + return Time(now.Add(time.Duration(d))), nil + } + + for _, format := range []string{ + time.ANSIC, + time.UnixDate, + time.RubyDate, + time.RFC822, + time.RFC822Z, + time.RFC850, + time.RFC1123, + time.RFC1123Z, + time.RFC3339, + time.RFC3339Nano, + time.Kitchen, + time.Stamp, + time.StampMilli, + time.StampMicro, + time.StampNano, + } { + t, err := time.Parse(format, s) + if err == nil { + return Time(t), nil + } + } + + return Time{}, fmt.Errorf("unsupported time representation: %q", s) +} + +func (t Time) IsZero() bool { + return time.Time(t).IsZero() +} + +func (t Time) String() string { + return t.text(time.Now(), Duration.String) +} + +func (t Time) GoString() string { + return fmt.Sprintf("human.Time{s:%d,ns:%d}", + time.Time(t).Unix(), + time.Time(t).Nanosecond()) +} + +// Format satisfies the fmt.Formatter interface. +// +// The method supports the following formatting verbs: +// +// s duration relative to now (same as calling String) +// v sam as the 's' format, unless '#' is set to print the go value +// +// The 's' and 'v' formatting verbs also interpret the options: +// +// '-' outputs full names of the time units instead of abbreviations +// '.' followed by a digit to limit the precision of the output +func (t Time) Format(w fmt.State, v rune) { + t.formatAt(w, v, time.Now()) +} + +func (t Time) formatAt(w fmt.State, v rune, now time.Time) { + _, _ = io.WriteString(w, t.format(w, v, now)) +} + +func (t Time) format(w fmt.State, v rune, now time.Time) string { + switch v { + case 's': + return t.text(now, func(d Duration) string { return d.format(w, v, now) }) + case 'v': + if w.Flag('#') { + return t.GoString() + } + return t.format(w, 's', now) + default: + return printError(v, t, time.Time(t)) + } +} + +func (t Time) Text(now time.Time) string { + return t.text(now, func(d Duration) string { return d.Text(now) }) +} + +func (t Time) text(now time.Time, format func(Duration) string) string { + if t.IsZero() { + return "(none)" + } + d := Duration(now.Sub(time.Time(t))) + switch { + case d > 0: + return format(d) + " ago" + case d < 0: + return format(-d) + " later" + default: + return "now" + } +} + +func (t Time) Formatter(now time.Time) fmt.Formatter { + return formatter(func(w fmt.State, v rune) { t.formatAt(w, v, now) }) +} + +func (t Time) MarshalJSON() ([]byte, error) { + return time.Time(t).MarshalJSON() +} + +func (t *Time) UnmarshalJSON(b []byte) error { + return ((*time.Time)(t)).UnmarshalJSON(b) +} + +func (t Time) MarshalYAML() (interface{}, error) { + return time.Time(t).Format(time.RFC3339Nano), nil +} + +func (t *Time) UnmarshalYAML(y *yaml.Node) error { + var s string + if err := y.Decode(&s); err != nil { + return err + } + p, err := time.Parse(time.RFC3339Nano, s) + if err != nil { + return err + } + *t = Time(p) + return nil +} + +func (t Time) MarshalText() ([]byte, error) { + return []byte(t.Text(time.Now())), nil +} + +func (t *Time) UnmarshalText(b []byte) error { + p, err := ParseTime(string(b)) + if err != nil { + return err + } + *t = p + return nil +} + +var ( + _ fmt.Formatter = Time{} + _ fmt.GoStringer = Time{} + _ fmt.Stringer = Time{} + + _ json.Marshaler = Time{} + _ json.Unmarshaler = (*Time)(nil) + + _ yaml.IsZeroer = Time{} + _ yaml.Marshaler = Time{} + _ yaml.Unmarshaler = (*Time)(nil) + + _ encoding.TextMarshaler = Time{} + _ encoding.TextUnmarshaler = (*Time)(nil) +) diff --git a/internal/print/human/time_test.go b/internal/print/human/time_test.go new file mode 100644 index 00000000..600ccf12 --- /dev/null +++ b/internal/print/human/time_test.go @@ -0,0 +1,153 @@ +package human + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + yaml "gopkg.in/yaml.v3" +) + +func TestTimeParse(t *testing.T) { + now := time.Now() + end := now.Add(1 * time.Second) + + for _, test := range []struct { + in string + out Duration + }{ + {in: "now", out: 0}, + + {in: "1ns ago", out: -Nanosecond}, + {in: "1µs ago", out: -Microsecond}, + {in: "1ms ago", out: -Millisecond}, + {in: "1s ago", out: -Second}, + {in: "1m ago", out: -Minute}, + {in: "1h ago", out: -Hour}, + + {in: "1 nanosecond ago", out: -Nanosecond}, + {in: "1 microsecond ago", out: -Microsecond}, + {in: "1 millisecond ago", out: -Millisecond}, + {in: "1 second ago", out: -Second}, + {in: "1 minute ago", out: -Minute}, + {in: "1 hour ago", out: -Hour}, + + {in: "1 day ago", out: -24 * Hour}, + {in: "2 days ago", out: -48 * Hour}, + {in: "1 week ago", out: -7 * 24 * Hour}, + {in: "2 weeks ago", out: -14 * 24 * Hour}, + + {in: "0s later", out: 0}, + + {in: "1ns later", out: Nanosecond}, + {in: "1µs later", out: Microsecond}, + {in: "1ms later", out: Millisecond}, + {in: "1s later", out: Second}, + {in: "1m later", out: Minute}, + {in: "1h later", out: Hour}, + + {in: "1 nanosecond later", out: Nanosecond}, + {in: "1 microsecond later", out: Microsecond}, + {in: "1 millisecond later", out: Millisecond}, + {in: "1 second later", out: Second}, + {in: "1 minute later", out: Minute}, + {in: "1 hour later", out: Hour}, + + {in: "1 day later", out: 24 * Hour}, + {in: "2 days later", out: 48 * Hour}, + {in: "1 week later", out: 7 * 24 * Hour}, + {in: "2 weeks later", out: 14 * 24 * Hour}, + + {in: "1.5m ago", out: -1*Minute - 30*Second}, + + {in: end.Format(time.RFC3339Nano), out: 1 * Second}, + } { + t.Run(test.in, func(t *testing.T) { + p, err := ParseTimeAt(test.in, now) + if err != nil { + t.Fatal(err) + } + if d := Duration(time.Time(p).Sub(now)); d != test.out { + t.Error("parsed time delta mismatch:", d, "!=", test.out) + } + }) + } +} + +func TestTimeFormat(t *testing.T) { + now := time.Now() + + for _, test := range []struct { + in Duration + fmt string + out string + }{ + {fmt: "%v", out: "now", in: 0}, + + {fmt: "%v", out: "1ns ago", in: -Nanosecond}, + {fmt: "%v", out: "1µs ago", in: -Microsecond}, + {fmt: "%v", out: "1ms ago", in: -Millisecond}, + {fmt: "%v", out: "1s ago", in: -Second}, + {fmt: "%v", out: "1m ago", in: -Minute}, + {fmt: "%v", out: "1h ago", in: -Hour}, + + {fmt: "%v", out: "1d ago", in: -24 * Hour}, + {fmt: "%v", out: "2d ago", in: -48 * Hour}, + {fmt: "%v", out: "1w ago", in: -7 * 24 * Hour}, + {fmt: "%v", out: "2w ago", in: -14 * 24 * Hour}, + {fmt: "%v", out: "1mo ago", in: -33 * 24 * Hour}, + {fmt: "%v", out: "2mo ago", in: -66 * 24 * Hour}, + {fmt: "%v", out: "1y ago", in: -400 * 24 * Hour}, + {fmt: "%v", out: "2y ago", in: -800 * 24 * Hour}, + + {fmt: "%v", out: "1ns later", in: Nanosecond}, + {fmt: "%v", out: "1µs later", in: Microsecond}, + {fmt: "%v", out: "1ms later", in: Millisecond}, + {fmt: "%v", out: "1s later", in: Second}, + {fmt: "%v", out: "1m later", in: Minute}, + {fmt: "%v", out: "1h later", in: Hour}, + + {fmt: "%v", out: "1d later", in: 24 * Hour}, + {fmt: "%v", out: "2d later", in: 48 * Hour}, + {fmt: "%v", out: "1w later", in: 7 * 24 * Hour}, + {fmt: "%v", out: "2w later", in: 14 * 24 * Hour}, + {fmt: "%v", out: "1mo later", in: 33 * 24 * Hour}, + {fmt: "%v", out: "2mo later", in: 66 * 24 * Hour}, + {fmt: "%v", out: "1y later", in: 400 * 24 * Hour}, + {fmt: "%v", out: "2y later", in: 800 * 24 * Hour}, + + {fmt: "%v", out: "1m later", in: 1*Minute + 30*Second}, + {fmt: "%+.1v", out: "2 hours later", in: 2*Hour + 1*Minute + 30*Second}, + {fmt: "%+.2v", out: "2 hours 1 minute later", in: 2*Hour + 1*Minute + 30*Second}, + {fmt: "%+.3v", out: "2 hours 1 minute 30 seconds later", in: 2*Hour + 1*Minute + 30*Second}, + } { + t.Run(test.out, func(t *testing.T) { + if s := fmt.Sprintf(test.fmt, Time(now.Add(time.Duration(test.in))).Formatter(now)); s != test.out { + t.Error("time string mismatch:", s, "!=", test.out) + } + }) + } +} + +func TestTimeJSON(t *testing.T) { + testTimeEncoding(t, Time(time.Now()), json.Marshal, json.Unmarshal) +} + +func TestTimeYAML(t *testing.T) { + testTimeEncoding(t, Time(time.Now()), yaml.Marshal, yaml.Unmarshal) +} + +func testTimeEncoding(t *testing.T, x Time, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) { + b, err := marshal(x) + if err != nil { + t.Fatal("marshal error:", err) + } + + v := Time{} + if err := unmarshal(b, &v); err != nil { + t.Error("unmarshal error:", err) + } else if !time.Time(v).Equal(time.Time(x)) { + t.Error("value mismatch:", v, "!=", x) + } +} diff --git a/internal/print/jsonprint/writer.go b/internal/print/jsonprint/writer.go new file mode 100644 index 00000000..8d16034b --- /dev/null +++ b/internal/print/jsonprint/writer.go @@ -0,0 +1,30 @@ +package jsonprint + +import ( + "encoding/json" + "io" + + "github.com/stealthrocket/timecraft/internal/stream" +) + +func NewWriter[T any](w io.Writer) stream.WriteCloser[T] { + e := json.NewEncoder(w) + e.SetEscapeHTML(false) + e.SetIndent("", " ") + return writer[T]{e} +} + +type writer[T any] struct{ *json.Encoder } + +func (w writer[T]) Write(values []T) (int, error) { + for n := range values { + if err := w.Encode(values[n]); err != nil { + return n, err + } + } + return len(values), nil +} + +func (w writer[T]) Close() error { + return nil +} diff --git a/internal/print/jsonprint/writer_test.go b/internal/print/jsonprint/writer_test.go new file mode 100644 index 00000000..8e6e09ca --- /dev/null +++ b/internal/print/jsonprint/writer_test.go @@ -0,0 +1,46 @@ +package jsonprint_test + +import ( + "bytes" + "testing" + + "github.com/stealthrocket/timecraft/internal/assert" + "github.com/stealthrocket/timecraft/internal/print/jsonprint" +) + +type tag struct { + Name string `json:"name"` + Value string `json:"value"` +} + +func TestWriteNothing(t *testing.T) { + b := new(bytes.Buffer) + w := jsonprint.NewWriter[tag](b) + assert.OK(t, w.Close()) + assert.Equal(t, b.String(), "") +} + +func TestWriteValues(t *testing.T) { + b := new(bytes.Buffer) + w := jsonprint.NewWriter[tag](b) + _, err := w.Write([]tag{ + {Name: "one", Value: "1"}, + {Name: "two", Value: "2"}, + {Name: "three", Value: "3"}, + }) + assert.OK(t, err) + assert.OK(t, w.Close()) + assert.Equal(t, b.String(), `{ + "name": "one", + "value": "1" +} +{ + "name": "two", + "value": "2" +} +{ + "name": "three", + "value": "3" +} +`) +} diff --git a/internal/print/textprint/table.go b/internal/print/textprint/table.go new file mode 100644 index 00000000..c90031e7 --- /dev/null +++ b/internal/print/textprint/table.go @@ -0,0 +1,96 @@ +package textprint + +import ( + "io" + "reflect" + "strings" + "text/tabwriter" + + "github.com/stealthrocket/timecraft/internal/stream" +) + +func NewTableWriter[T any](w io.Writer) stream.WriteCloser[T] { + t := &tableWriter[T]{ + writer: tabwriter.NewWriter(w, 0, 4, 2, ' ', 0), + valueOf: func(values []T, index int) reflect.Value { + return reflect.ValueOf(&values[index]).Elem() + }, + } + + writeString := func(w io.Writer, s string) { + _, err := io.WriteString(w, s) + if err != nil { + panic(err) + } + } + + var v T + valueType := reflect.TypeOf(v) + if valueType.Kind() == reflect.Pointer { + valueType = valueType.Elem() + t.valueOf = func(values []T, index int) reflect.Value { + return reflect.ValueOf(values[index]).Elem() + } + } + + for i, f := range reflect.VisibleFields(valueType) { + if i != 0 { + writeString(t.writer, "\t") + } + + name := f.Name + if textTag := f.Tag.Get("text"); textTag != "" { + tag := strings.Split(textTag, ",") + name, tag = tag[0], tag[1:] + for _, s := range tag { + switch s { + // TODO: other tags + } + } + } + + if name == "-" { + continue + } + + writeString(t.writer, name) + t.encoders = append(t.encoders, encodeFuncOfStructField(f.Type, f.Index)) + } + + writeString(t.writer, "\n") + return t +} + +type tableWriter[T any] struct { + writer *tabwriter.Writer + encoders []encodeFunc + valueOf func([]T, int) reflect.Value +} + +func (t *tableWriter[T]) Write(values []T) (int, error) { + for n := range values { + v := t.valueOf(values, n) + w := io.Writer(t.writer) + + for i, enc := range t.encoders { + if i != 0 { + _, err := io.WriteString(w, "\t") + if err != nil { + return n, err + } + } + if err := enc(w, v); err != nil { + return n, err + } + } + + if _, err := io.WriteString(w, "\n"); err != nil { + return n, err + } + } + return len(values), nil +} + +func (t *tableWriter[T]) Close() error { + return t.writer.Flush() +} diff --git a/internal/print/textprint/table_test.go b/internal/print/textprint/table_test.go new file mode 100644 index 00000000..9eaee0cf --- /dev/null +++ b/internal/print/textprint/table_test.go @@ -0,0 +1,39 @@ +package textprint_test + +import ( + "bytes" + "testing" + + "github.com/stealthrocket/timecraft/internal/assert" + "github.com/stealthrocket/timecraft/internal/print/textprint" +) + +type person struct { + FirstName string `text:"FIRST NAME"` + LastName string `text:"LAST NAME"` + Age int `text:"AGE"` +} + +func TestTableWriteNothing(t *testing.T) { + b := new(bytes.Buffer) + w := textprint.NewTableWriter[person](b) + assert.OK(t, w.Close()) + assert.Equal(t, b.String(), "FIRST NAME LAST NAME AGE\n") +} + +func TestTableWriteValues(t *testing.T) { + b := new(bytes.Buffer) + w := textprint.NewTableWriter[person](b) + _, err := w.Write([]person{ + {FirstName: "Luke", LastName: "Skywalker", Age: 19}, + {FirstName: "Leia", LastName: "Skywalker", Age: 19}, + {FirstName: "Han", LastName: "Solo", Age: 19}, + }) + assert.OK(t, err) + assert.OK(t, w.Close()) + assert.Equal(t, b.String(), `FIRST NAME LAST NAME AGE +Luke Skywalker 19 +Leia Skywalker 19 +Han Solo 19 +`) +} diff --git a/internal/print/textprint/textprint.go b/internal/print/textprint/textprint.go new file mode 100644 index 00000000..312dfa07 --- /dev/null +++ b/internal/print/textprint/textprint.go @@ -0,0 +1,168 @@ +package textprint + +import ( + "fmt" + "io" + "reflect" + + "golang.org/x/exp/slices" +) + +type encodeFunc func(io.Writer, reflect.Value) error + +func encodeBool(w io.Writer, v reflect.Value) error { + _, err := fmt.Fprintf(w, "%t", v.Bool()) + return err +} + +func encodeInt(w io.Writer, v reflect.Value) error { + _, err := fmt.Fprintf(w, "%d", v.Int()) + return err +} + +func encodeUint(w io.Writer, v reflect.Value) error { + _, err := fmt.Fprintf(w, "%d", v.Uint()) + return err +} + +func encodeString(w io.Writer, v reflect.Value) error { + _, err := io.WriteString(w, v.String()) + return err +} + +func encodeStringer(w io.Writer, v reflect.Value) error { + _, err := io.WriteString(w, v.Interface().(fmt.Stringer).String()) + return err +} + +func encodeFormatter(w io.Writer, v reflect.Value) error { + _, err := fmt.Fprintf(w, "%v", v.Interface()) + return err +} + +func encodeFuncOf(t reflect.Type) encodeFunc { + if t.Implements(reflect.TypeOf((*fmt.Formatter)(nil)).Elem()) { + return encodeFormatter + } + if t.Implements(reflect.TypeOf((*fmt.Stringer)(nil)).Elem()) { + return encodeStringer + } + switch t.Kind() { + case reflect.Bool: + return encodeBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return encodeInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return encodeUint + case reflect.String: + return encodeString + case reflect.Pointer: + return encodeFuncOfPointer(t.Elem()) + case reflect.Slice: + return encodeFuncOfSlice(t.Elem()) + case reflect.Map: + return encodeFuncOfMap(t.Key(), t.Elem()) + default: + panic("cannot encode values of type " + t.String()) + } +} + +func encodeFuncOfPointer(t reflect.Type) encodeFunc { + encode := encodeFuncOf(t) + return func(w io.Writer, v reflect.Value) error { + if v.IsNil() { + _, err := fmt.Fprintf(w, "(none)") + return err + } else { + return encode(w, v.Elem()) + } + } +} + +func encodeFuncOfSlice(t reflect.Type) encodeFunc { + encode := encodeFuncOf(t) + return func(w io.Writer, v reflect.Value) error { + for i, n := 0, v.Len(); i < n; i++ { + if i != 0 { + if _, err := io.WriteString(w, ", "); err != nil { + return err + } + } + if err := encode(w, v.Index(i)); err != nil { + return err + } + } + return nil + } +} + +func encodeFuncOfMap(key, val reflect.Type) encodeFunc { + lessFunc := lessFuncOf(key) + encodeKey := encodeFuncOf(key) + encodeVal := encodeFuncOf(val) + return func(w io.Writer, v reflect.Value) error { + keys := v.MapKeys() + slices.SortFunc(keys, lessFunc) + + for i, key := range keys { + if i != 0 { + if _, err := io.WriteString(w, ", "); err != nil { + return err + } + } + if err := encodeKey(w, key); err != nil { + return err + } + if _, err := io.WriteString(w, ":"); err != nil { + return err + } + if err := encodeVal(w, v.MapIndex(key)); err != nil { + return err + } + } + + return nil + } +} + +func encodeFuncOfStructField(t reflect.Type, index []int) encodeFunc { + encode := encodeFuncOf(t) + return func(w io.Writer, v reflect.Value) error { + return encode(w, v.FieldByIndex(index)) + } +} + +type lessFunc func(reflect.Value, reflect.Value) bool + +func lessBool(v1, v2 reflect.Value) bool { + b1 := v1.Bool() + b2 := v2.Bool() + return !b1 && b1 != b2 +} + +func lessInt(v1, v2 reflect.Value) bool { + return v1.Int() < v2.Int() +} + +func lessUint(v1, v2 reflect.Value) bool { + return v1.Uint() < v2.Uint() +} + +func lessString(v1, v2 reflect.Value) bool { + return v1.String() < v2.String() +} + +func lessFuncOf(t reflect.Type) lessFunc { + switch t.Kind() { + case reflect.Bool: + return lessBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return lessInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return lessUint + case reflect.String: + return lessString + default: + panic("cannot compare values of type " + t.String()) + } +} diff --git a/internal/print/yamlprint/writer.go b/internal/print/yamlprint/writer.go new file mode 100644 index 00000000..e12c4157 --- /dev/null +++ b/internal/print/yamlprint/writer.go @@ -0,0 +1,36 @@ +package yamlprint + +import ( + "io" + + "gopkg.in/yaml.v3" + + "github.com/stealthrocket/timecraft/internal/stream" +) + +func NewWriter[T any](w io.Writer) stream.WriteCloser[T] { + e := yaml.NewEncoder(w) + e.SetIndent(2) + return writer[T]{e} +} + +type writer[T any] struct{ *yaml.Encoder } + +func (w writer[T]) Write(values []T) (int, error) { + for i := range values { + if err := w.Encode(values[i]); err != nil { + return i, err + } + } + return len(values), nil +} + +func (w writer[T]) Close() error { + err := w.Encoder.Close() + if err != nil { + if s := err.Error(); s == `yaml: expected STREAM-START` { + err = nil + } + } + return err +} diff --git a/internal/print/yamlprint/writer_test.go b/internal/print/yamlprint/writer_test.go new file mode 100644 index 00000000..ea061112 --- /dev/null +++ b/internal/print/yamlprint/writer_test.go @@ -0,0 +1,42 @@ +package yamlprint_test + +import ( + "bytes" + "testing" + + "github.com/stealthrocket/timecraft/internal/assert" + "github.com/stealthrocket/timecraft/internal/print/yamlprint" +) + +type tag struct { + Name string `yaml:"name"` + Value string `yaml:"value"` +} + +func TestWriteNothing(t *testing.T) { + b := new(bytes.Buffer) + w := yamlprint.NewWriter[tag](b) + assert.OK(t, w.Close()) + assert.Equal(t, b.String(), "") +} + +func TestWriteValues(t *testing.T) { + b := new(bytes.Buffer) + w := yamlprint.NewWriter[tag](b) + _, err := w.Write([]tag{ + {Name: "one", Value: "1"}, + {Name: "two", Value: "2"}, + {Name: "three", Value: "3"}, + }) + assert.OK(t, err) + assert.OK(t, w.Close()) + assert.Equal(t, b.String(), `name: one +value: "1" +--- +name: two +value: "2" +--- +name: three +value: "3" +`) +} diff --git a/internal/stream/convert.go b/internal/stream/convert.go new file mode 100644 index 00000000..61cce6c5 --- /dev/null +++ b/internal/stream/convert.go @@ -0,0 +1,63 @@ +package stream + +func ConvertReader[To, From any](base Reader[From], conv func(From) (To, error)) Reader[To] { + return &convertReader[To, From]{base: base, conv: conv} +} + +type convertReader[To, From any] struct { + base Reader[From] + from []From + conv func(From) (To, error) +} + +func (r *convertReader[To, From]) Read(values []To) (n int, err error) { + for n < len(values) { + if i := len(values) - n; i <= cap(r.from) { + r.from = r.from[:i] + } else { + r.from = make([]From, i) + } + + rn, err := r.base.Read(r.from) + + for _, from := range r.from[:rn] { + to, err := r.conv(from) + if err != nil { + return n, err + } + values[n] = to + n++ + } + + if err != nil { + return n, err + } + } + return n, nil +} + +func ConvertWriter[To, From any](base Writer[To], conv func(From) (To, error)) Writer[From] { + return &convertWriter[To, From]{base: base, conv: conv} +} + +type convertWriter[To, From any] struct { + base Writer[To] + to []To + conv func(From) (To, error) +} + +func (w *convertWriter[To, From]) Write(values []From) (n int, err error) { + defer func() { + w.to = w.to[:0] + }() + + for _, from := range values { + to, err := w.conv(from) + if err != nil { + return 0, err + } + w.to = append(w.to, to) + } + + return w.base.Write(w.to) +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 856bde5f..23d591af 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -4,8 +4,8 @@ package stream import "io" -// Reader is an interface implemented by types that produce a stream of values -// of type T. +// Reader is an interface implemented by types that read a stream of values of +// type T. type Reader[T any] interface { // Reads values from the stream, returning the number of values read and any // error that occurred. @@ -79,3 +79,63 @@ func ReadAll[T any](r Reader[T]) ([]T, error) { } } } + +// Writer is an interface implemented by types that write a stream of values of +// type T. +type Writer[T any] interface { + Write(values []T) (int, error) +} + +// WriteCloser represents a closable stream of values of T. +// +// WriteClosers is like io.WriteCloser for values of any type. +type WriteCloser[T any] interface { + Writer[T] + io.Closer +} + +func NewWriteCloser[T any](w Writer[T], c io.Closer) WriteCloser[T] { + return &writeCloser[T]{writer: w, closer: c} +} + +type writeCloser[T any] struct { + writer Writer[T] + closer io.Closer +} + +func (w *writeCloser[T]) Write(values []T) (int, error) { + return w.writer.Write(values) +} + +func (w *writeCloser[T]) Close() error { + return w.closer.Close() +} + +// Copy writes values read from r to w, returning the number of values written +// and any error other than io.EOF. +func Copy[T any](w Writer[T], r Reader[T]) (int64, error) { + b := make([]T, 20) + n := int64(0) + + for { + rn, err := r.Read(b) + + if rn > 0 { + wn, err := w.Write(b[:rn]) + n += int64(wn) + if err != nil { + return n, err + } + if wn < rn { + return n, io.ErrNoProgress + } + } + + if err != nil { + if err == io.EOF { + err = nil + } + return n, err + } + } +} diff --git a/internal/timemachine/registry.go b/internal/timemachine/registry.go index 962e4bef..ab31b933 100644 --- a/internal/timemachine/registry.go +++ b/internal/timemachine/registry.go @@ -3,13 +3,18 @@ package timemachine import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "io" "path" "strconv" + "strings" "time" + "golang.org/x/exp/slices" + "github.com/stealthrocket/timecraft/format" "github.com/stealthrocket/timecraft/internal/object" "github.com/stealthrocket/timecraft/internal/stream" @@ -21,40 +26,56 @@ var ( ErrNoLogRecords = errors.New("process has no records") ) -type ModuleInfo struct { - ID Hash - Size int64 - CreatedAt time.Time +type TimeRange struct { + Start, End time.Time +} + +func Between(then, now time.Time) TimeRange { + return TimeRange{Start: then, End: now} } -type LogSegmentInfo struct { +func Since(then time.Time) TimeRange { + return Between(then, time.Now().In(then.Location())) +} + +func Until(now time.Time) TimeRange { + return Between(time.Unix(0, 0).In(now.Location()), now) +} + +func (tr TimeRange) Duration() time.Duration { + return tr.End.Sub(tr.Start) +} + +type LogSegment struct { Number int Size int64 CreatedAt time.Time } type Registry struct { - objects object.Store -} - -func NewRegistry(objects object.Store) *Registry { - return &Registry{objects: objects} + // The object store that the registry uses to load and store data. + Store object.Store + // List of tags that are added to every object created by this registry. + CreateTags []object.Tag + // List of tags that are added to every query selecting objects from this + // registry. + SelectTags []object.Tag } -func (reg *Registry) CreateModule(ctx context.Context, module *format.Module) (*format.Descriptor, error) { - return reg.createObject(ctx, module) +func (reg *Registry) CreateModule(ctx context.Context, module *format.Module, tags ...object.Tag) (*format.Descriptor, error) { + return reg.createObject(ctx, module, tags) } -func (reg *Registry) CreateRuntime(ctx context.Context, runtime *format.Runtime) (*format.Descriptor, error) { - return reg.createObject(ctx, runtime) +func (reg *Registry) CreateRuntime(ctx context.Context, runtime *format.Runtime, tags ...object.Tag) (*format.Descriptor, error) { + return reg.createObject(ctx, runtime, tags) } -func (reg *Registry) CreateConfig(ctx context.Context, config *format.Config) (*format.Descriptor, error) { - return reg.createObject(ctx, config) +func (reg *Registry) CreateConfig(ctx context.Context, config *format.Config, tags ...object.Tag) (*format.Descriptor, error) { + return reg.createObject(ctx, config, tags) } -func (reg *Registry) CreateProcess(ctx context.Context, process *format.Process) (*format.Descriptor, error) { - return reg.createObject(ctx, process) +func (reg *Registry) CreateProcess(ctx context.Context, process *format.Process, tags ...object.Tag) (*format.Descriptor, error) { + return reg.createObject(ctx, process, tags) } func (reg *Registry) LookupModule(ctx context.Context, hash format.Hash) (*format.Module, error) { @@ -77,8 +98,24 @@ func (reg *Registry) LookupProcess(ctx context.Context, hash format.Hash) (*form return process, reg.lookupObject(ctx, hash, process) } -func (reg *Registry) LookupDescriptor(ctx context.Context, hash format.Hash) (*format.Descriptor, error) { - return reg.lookupDescriptor(ctx, reg.descriptorKey(hash)) +func (reg *Registry) ListModules(ctx context.Context, timeRange TimeRange, tags ...object.Tag) stream.ReadCloser[*format.Descriptor] { + return reg.listObjects(ctx, "module", timeRange, tags) +} + +func (reg *Registry) ListRuntimes(ctx context.Context, timeRange TimeRange, tags ...object.Tag) stream.ReadCloser[*format.Descriptor] { + return reg.listObjects(ctx, "runtime", timeRange, tags) +} + +func (reg *Registry) ListConfigs(ctx context.Context, timeRange TimeRange, tags ...object.Tag) stream.ReadCloser[*format.Descriptor] { + return reg.listObjects(ctx, "config", timeRange, tags) +} + +func (reg *Registry) ListProcesses(ctx context.Context, timeRange TimeRange, tags ...object.Tag) stream.ReadCloser[*format.Descriptor] { + return reg.listObjects(ctx, "process", timeRange, tags) +} + +func (reg *Registry) ListResources(ctx context.Context, resourceType string, timeRange TimeRange, tags ...object.Tag) stream.ReadCloser[*format.Descriptor] { + return reg.listObjects(ctx, resourceType, timeRange, tags) } func errorCreateObject(hash format.Hash, value format.Resource, err error) error { @@ -89,65 +126,108 @@ func errorLookupObject(hash format.Hash, value format.Resource, err error) error return fmt.Errorf("lookup object: %s: %s: %w", hash, value.ContentType(), err) } -func errorLookupDescriptor(hash format.Hash, value format.Resource, err error) error { - return fmt.Errorf("lookup descriptor: %s: %s: %w", hash, value.ContentType(), err) +func errorListObjects(mediaType format.MediaType, err error) error { + return fmt.Errorf("list objects: %s: %w", mediaType, err) } -func (reg *Registry) createObject(ctx context.Context, value format.ResourceMarshaler) (*format.Descriptor, error) { - b, err := value.MarshalResource() - if err != nil { - return nil, err +func resourceTypeOf(mediaType format.MediaType) string { + const prefix = "application/vnd.timecraft." + if strings.HasPrefix(string(mediaType), prefix) { + s, _, _ := strings.Cut(string(mediaType[len(prefix):]), ".") + return s } - hash := SHA256(b) - name := reg.objectKey(hash) - desc := reg.descriptorKey(hash) + return "unknown" +} - descriptor, err := reg.lookupDescriptor(ctx, desc) - if err == nil { - return descriptor, nil - } - if !errors.Is(err, object.ErrNotExist) { - return nil, errorLookupDescriptor(hash, value, err) +func appendTagFilters(filters []object.Filter, tags []object.Tag) []object.Filter { + for _, tag := range tags { + filters = append(filters, object.MATCH(tag.Name, tag.Value)) } + return filters +} - descriptor = &format.Descriptor{ - MediaType: value.ContentType(), - Digest: hash, - Size: int64(len(b)), - } - d, err := descriptor.MarshalResource() - if err != nil { - return nil, errorCreateObject(hash, value, err) +func assignTags(annotations map[string]string, tags []object.Tag) { + for _, tag := range tags { + annotations[tag.Name] = tag.Value } +} - if err := reg.objects.CreateObject(ctx, desc, bytes.NewReader(d)); err != nil { - return nil, errorCreateObject(hash, value, err) +func makeTags(annotations map[string]string) []object.Tag { + tags := make([]object.Tag, 0, len(annotations)) + for name, value := range annotations { + tags = append(tags, object.Tag{ + Name: name, + Value: value, + }) } - if err := reg.objects.CreateObject(ctx, name, bytes.NewReader(b)); err != nil { - return nil, errorCreateObject(hash, value, err) + slices.SortFunc(tags, func(t1, t2 object.Tag) bool { + return t1.Name < t2.Name + }) + return tags +} + +func sha256Hash(data []byte, tags []object.Tag) format.Hash { + buf := object.AppendTags(make([]byte, 0, 256), tags...) + sha := sha256.New() + sha.Write(data) + sha.Write(buf) + return format.Hash{ + Algorithm: "sha256", + Digest: hex.EncodeToString(sha.Sum(nil)), } - return descriptor, nil } -func (reg *Registry) lookupDescriptor(ctx context.Context, key string) (*format.Descriptor, error) { - r, err := reg.objects.ReadObject(ctx, key) +func (reg *Registry) createObject(ctx context.Context, value format.ResourceMarshaler, extraTags []object.Tag) (*format.Descriptor, error) { + b, err := value.MarshalResource() if err != nil { return nil, err } - defer r.Close() - b, err := io.ReadAll(r) - if err != nil { - return nil, err + mediaType := value.ContentType() + + annotations := make(map[string]string, 1+len(extraTags)+len(reg.CreateTags)) + assignTags(annotations, reg.CreateTags) + assignTags(annotations, extraTags) + assignTags(annotations, []object.Tag{ + { + Name: "timecraft.object.media-type", + Value: mediaType.String(), + }, + { + Name: "timecraft.object.created-at", + Value: time.Now().UTC().Format(time.RFC3339), + }, + { + Name: "timecraft.object.resource-type", + Value: resourceTypeOf(mediaType), + }, + }) + + tags := makeTags(annotations) + hash := sha256Hash(b, tags) + name := reg.objectKey(hash) + desc := &format.Descriptor{ + MediaType: mediaType, + Digest: hash, + Size: int64(len(b)), + Annotations: annotations, } - descriptor := new(format.Descriptor) - if err := descriptor.UnmarshalResource(b); err != nil { - return nil, err + + if _, err := reg.Store.StatObject(ctx, name); err != nil { + if !errors.Is(err, object.ErrNotExist) { + return nil, errorCreateObject(hash, value, err) + } + } else { + return desc, nil + } + + if err := reg.Store.CreateObject(ctx, name, bytes.NewReader(b), tags...); err != nil { + return nil, errorCreateObject(hash, value, err) } - return descriptor, nil + return desc, nil } func (reg *Registry) lookupObject(ctx context.Context, hash format.Hash, value format.ResourceUnmarshaler) error { - r, err := reg.objects.ReadObject(ctx, reg.objectKey(hash)) + r, err := reg.Store.ReadObject(ctx, reg.objectKey(hash)) if err != nil { return errorLookupObject(hash, value, err) } @@ -162,12 +242,44 @@ func (reg *Registry) lookupObject(ctx context.Context, hash format.Hash, value f return nil } -func (reg *Registry) descriptorKey(hash format.Hash) string { - return "obj/" + hash.String() + "/descriptor.json" +func (reg *Registry) listObjects(ctx context.Context, resourceType string, timeRange TimeRange, matchTags []object.Tag) stream.ReadCloser[*format.Descriptor] { + if !timeRange.Start.IsZero() { + timeRange.Start = timeRange.Start.Add(-1) + } + + filters := []object.Filter{ + object.MATCH("timecraft.object.resource-type", resourceType), + object.AFTER(timeRange.Start), + object.BEFORE(timeRange.End), + } + filters = appendTagFilters(filters, reg.SelectTags) + filters = appendTagFilters(filters, matchTags) + + reader := reg.Store.ListObjects(ctx, "obj/", filters...) + return convert(reader, func(info object.Info) (*format.Descriptor, error) { + m, ok := info.Lookup("timecraft.object.media-type") + if !ok { + m = "application/octet-stream" + } + mediaType := format.MediaType(m) + hash, err := format.ParseHash(path.Base(info.Name)) + if err != nil { + return nil, errorListObjects(mediaType, err) + } + desc := &format.Descriptor{ + MediaType: mediaType, + Digest: hash, + Size: info.Size, + Annotations: make(map[string]string, len(info.Tags)), + } + assignTags(desc.Annotations, info.Tags) + delete(desc.Annotations, "timecraft.object.media-type") + return desc, nil + }) } func (reg *Registry) objectKey(hash format.Hash) string { - return "obj/" + hash.String() + "/content" + return "obj/" + hash.String() } func (reg *Registry) CreateLogManifest(ctx context.Context, processID format.UUID, manifest *format.Manifest) error { @@ -175,7 +287,7 @@ func (reg *Registry) CreateLogManifest(ctx context.Context, processID format.UUI if err != nil { return err } - return reg.objects.CreateObject(ctx, reg.manifestKey(processID), bytes.NewReader(b)) + return reg.Store.CreateObject(ctx, reg.manifestKey(processID), bytes.NewReader(b)) } func (reg *Registry) CreateLogSegment(ctx context.Context, processID format.UUID, segmentNumber int) (io.WriteCloser, error) { @@ -184,7 +296,7 @@ func (reg *Registry) CreateLogSegment(ctx context.Context, processID format.UUID done := make(chan struct{}) go func() { defer close(done) - r.CloseWithError(reg.objects.CreateObject(ctx, name, r)) + r.CloseWithError(reg.Store.CreateObject(ctx, name, r)) }() return &logSegmentWriter{writer: w, done: done}, nil } @@ -204,14 +316,15 @@ func (w *logSegmentWriter) Close() error { return err } -func (reg *Registry) ListLogSegments(ctx context.Context, processID format.UUID) stream.Reader[LogSegmentInfo] { - return convert(reg.objects.ListObjects(ctx, "log/"+processID.String()+"/data"), func(info object.Info) (LogSegmentInfo, error) { +func (reg *Registry) ListLogSegments(ctx context.Context, processID format.UUID) stream.Reader[LogSegment] { + reader := reg.Store.ListObjects(ctx, "log/"+processID.String()+"/data") + return convert(reader, func(info object.Info) (LogSegment, error) { number := path.Base(info.Name) n, err := strconv.ParseInt(number, 16, 32) if err != nil || n < 0 { - return LogSegmentInfo{}, fmt.Errorf("invalid log segment entry: %q", info.Name) + return LogSegment{}, fmt.Errorf("invalid log segment entry: %q", info.Name) } - segment := LogSegmentInfo{ + segment := LogSegment{ Number: int(n), Size: info.Size, CreatedAt: info.CreatedAt, @@ -221,7 +334,7 @@ func (reg *Registry) ListLogSegments(ctx context.Context, processID format.UUID) } func (reg *Registry) LookupLogManifest(ctx context.Context, processID format.UUID) (*format.Manifest, error) { - r, err := reg.objects.ReadObject(ctx, reg.manifestKey(processID)) + r, err := reg.Store.ReadObject(ctx, reg.manifestKey(processID)) if err != nil { if errors.Is(err, object.ErrNotExist) { err = fmt.Errorf("%w: %s", ErrNoLogRecords, processID) @@ -241,7 +354,7 @@ func (reg *Registry) LookupLogManifest(ctx context.Context, processID format.UUI } func (reg *Registry) ReadLogSegment(ctx context.Context, processID format.UUID, segmentNumber int) (io.ReadCloser, error) { - r, err := reg.objects.ReadObject(ctx, reg.logKey(processID, segmentNumber)) + r, err := reg.Store.ReadObject(ctx, reg.logKey(processID, segmentNumber)) if err != nil { if errors.Is(err, object.ErrNotExist) { err = fmt.Errorf("%w: %s", ErrNoLogRecords, processID) @@ -259,42 +372,5 @@ func (reg *Registry) manifestKey(processID format.UUID) string { } func convert[To, From any](base stream.ReadCloser[From], conv func(From) (To, error)) stream.ReadCloser[To] { - return &convertReadCloser[To, From]{base: base, conv: conv} -} - -type convertReadCloser[To, From any] struct { - base stream.ReadCloser[From] - from []From - conv func(From) (To, error) -} - -func (r *convertReadCloser[To, From]) Close() error { - return r.base.Close() -} - -func (r *convertReadCloser[To, From]) Read(items []To) (n int, err error) { - for n < len(items) { - if i := len(items) - n; cap(r.from) <= i { - r.from = r.from[:i] - } else { - r.from = make([]From, i) - } - - rn, err := r.base.Read(r.from) - - for _, from := range r.from[:rn] { - to, err := r.conv(from) - if err != nil { - r.base.Close() - return n, err - } - items[n] = to - n++ - } - - if err != nil { - return n, err - } - } - return n, nil + return stream.NewReadCloser(stream.ConvertReader[To, From](base, conv), base) } diff --git a/internal/timemachine/registry_test.go b/internal/timemachine/registry_test.go index b9af4cae..ff38764d 100644 --- a/internal/timemachine/registry_test.go +++ b/internal/timemachine/registry_test.go @@ -16,11 +16,13 @@ import ( func TestRegistry(t *testing.T) { t.Run("CreateAndLookup", func(t *testing.T) { - dir, err := object.DirStore(t.TempDir()) + store, err := object.DirStore(t.TempDir()) if err != nil { t.Fatal(err) } - reg := timemachine.NewRegistry(dir) + reg := &timemachine.Registry{ + Store: store, + } testRegistryCreateAndLookup(t, reg, (*timemachine.Registry).CreateModule, @@ -79,7 +81,7 @@ type resource interface { format.ResourceUnmarshaler } -type createMethod[T any] func(*timemachine.Registry, context.Context, T) (*format.Descriptor, error) +type createMethod[T any] func(*timemachine.Registry, context.Context, T, ...object.Tag) (*format.Descriptor, error) type lookupMethod[T any] func(*timemachine.Registry, context.Context, format.Hash) (T, error) @@ -87,14 +89,10 @@ func testRegistryCreateAndLookup[T resource](t *testing.T, reg *timemachine.Regi t.Run(reflect.TypeOf(want).Elem().String(), func(t *testing.T) { ctx := context.Background() - d1, err := create(reg, ctx, want) - assert.OK(t, err) - - d2, err := reg.LookupDescriptor(ctx, d1.Digest) + desc, err := create(reg, ctx, want) assert.OK(t, err) - assert.DeepEqual(t, d1, d2) - got, err := lookup(reg, ctx, d1.Digest) + got, err := lookup(reg, ctx, desc.Digest) assert.OK(t, err) assert.DeepEqual(t, got, want) })