Loading flags/flags.go 0 → 100644 +494 −0 Original line number Diff line number Diff line // Copyright (c) 2016 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // Package flags provides an interface for automatically creating command line // options from a struct. // // Typically, if one wants to load from a yaml, one has to define a proper // struct, then yaml.Unmarshal(), this is all good. However, there are // situations where we want to load most of the configs from the file but // overriding some configs. // // Let's say we use a yaml to config our Db connections and upon start of the // application we load from the yaml file to get the necessary parameters to // create the connection. Our base.yaml looks like this // // base.yaml // --- // mysql: // user: 'foo' // password: 'xxxxxx' // mysql_defaults_file: ./mysql_defaults.ini // mysql_socket_path: /var/run/mysqld/mysqld.sock // ... more config options ... // // we want to load all the configs from it but we want to provide some // flexibility for the program to connect via a different db user. We could // define a --user command flag then after loading the yaml file, we override // the user field with what we get from --user flag. // // If there are many overriding like this, manual define these flags is // tedious. This package provides an automatic way to define this override, // which is, given a struct, it'll create all the flags which are name using // the field names of the struct. If one of these flags are set via command // line, the struct will be modified in-place to reflect the value from command // line, therefore the values of the fields in the struct are overridden // // YAML is just used as an example here. In practice, one can use any struct // to define flags. // // Let's say we have our configration object as the following. // // type logging struct { // Interval int // Path string // } // // type socket struct { // ReadTimeout time.Duration // WriteTimeout time.Duration // } // // type tcp struct { // ReadTimeout time.Duration // socket // } // // type network struct { // ReadTimeout time.Duration // WriteTimeout time.Duration // tcp // } // // type Cfg struct { // logging // network // } // // The following code // // func main() { // c := &Cfg{} // flags.ParseArgs(c, os.Args[1:]) // } // // will create the following flags // // -logging.interval int // logging.interval // -logging.path string // logging.path // -network.readtimeout duration // network.readtimeout // -network.tcp.readtimeout duration // network.tcp.readtimeout // -network.tcp.socket.readtimeout duration // network.tcp.socket.readtimeout // -network.tcp.socket.writetimeout duration // network.tcp.socket.writetimeout // -network.writetimeout duration // network.writetimeout // // flags to subcommands are naturally suported. // // func main() { // cmd := os.Args[1] // switch cmd { // case "new" // c1 := &Cfg1{} // ParseArgs(c1, os.Args[2:]) // case "update": // c2 := &Cfg2{} // ParseArgs(c2, os.Args[2:]) // // ... more sub commands ... // } // } // // One can set Flatten to true when calling NewFlagMakerAdv, in which case, // flags are created without namespacing. For example, // // type auth struct { // Token string // Tag float64 // } // // type credentials struct { // User string // Password string // auth // } // // type database struct { // DBName string // TableName string // credentials // } // // type Cfg struct { // logging // database // } // // func main() { // c := &Cfg{} // flags.ParseArgs(c, os.Args[1:]) // } // // will create the following flags // -dbname string // dbname // -interval int // interval // -password string // password // -path string // path // -tablename string // tablename // -tag float // tag // -token string // token // -user string // user // // Please be aware that usual GoLang flag creation rules apply, i.e., if there are // duplication in flag names (in the flattened case it's more likely to happen // unless the caller make due dilligence to create the struct properly), it panics. // // // Note that not all types can have command line flags created for. map, channel // and function type will not defien a flag corresponding to the field. Pointer // types are properly handled and slice type will create multi-value command // line flags. That is, e.g. if a field foo's type is []int, one can use // --foo 10 --foo 15 --foo 20 to override this field value to be // []int{10, 15, 20}. For now, only []int, []string and []float64 are supported // in this fashion. package flags import ( "flag" "fmt" "reflect" "strings" "time" ) // FlagMakingOptions control the way FlagMaker's behavior when defining flags. type FlagMakingOptions struct { // Use lower case flag names rather than the field name/tag name directly. UseLowerCase bool // Create flags in namespaced fashion Flatten bool // If there is a struct tag named 'TagName', use its value as the flag name. // The purpose is that, for yaml/json parsing we often have something like // Foobar string `yaml:"host_name"`, in which case the flag will be named // 'host_name' rather than 'foobar'. TagName string // If there is a struct tag named 'TagUsage', use its value as the usage description. TagUsage string } // FlagMaker enumerate all the exported fields of a struct recursively // and create corresponding command line flags. For anonymous fields, // they are only enumerated if they are pointers to structs. // Usual GoLang flag rules apply, e.g. duplicated flag names leads to // panic. type FlagMaker struct { opts *FlagMakingOptions // We don't consume os.Args directly unless told to. fs *flag.FlagSet } // NewFlagMaker creates a default FlagMaker which creates namespaced flags func NewFlagMaker() *FlagMaker { return NewFlagMakerAdv(&FlagMakingOptions{ UseLowerCase: true, Flatten: false, TagName: "yaml", TagUsage: "usage"}) } // NewFlagMakerAdv gives full control to create flags. func NewFlagMakerAdv(options *FlagMakingOptions) *FlagMaker { return &FlagMaker{ opts: options, fs: flag.NewFlagSet("xFlags", flag.ContinueOnError), } } // NewFlagMakerFlagSet gives full control to create flags. func NewFlagMakerFlagSet(options *FlagMakingOptions, fs *flag.FlagSet) *FlagMaker { return &FlagMaker{ opts: options, fs: fs, } } // ParseArgs parses the string arguments which should not contain the program name. // // obj is the struct to populate. args are the command line arguments, // typically obtained from os.Args. func ParseArgs(obj interface{}, args []string) ([]string, error) { fm := NewFlagMaker() return fm.ParseArgs(obj, args) } // PrintDefaults prints the default value and type of defined flags. // It just calls the standard 'flag' package's PrintDefaults. func (fm *FlagMaker) PrintDefaults() { fm.fs.PrintDefaults() } // ParseArgs parses the arguments based on the FlagMaker's setting. func (fm *FlagMaker) ParseArgs(obj interface{}, args []string) ([]string, error) { v := reflect.ValueOf(obj) if v.Kind() != reflect.Ptr { return args, fmt.Errorf("top level object must be a pointer. %v is passed", v.Type()) } if v.IsNil() { return args, fmt.Errorf("top level object cannot be nil") } switch e := v.Elem(); e.Kind() { case reflect.Struct: fm.enumerateAndCreate("", e, "") case reflect.Interface: if e.Elem().Kind() == reflect.Ptr { fm.enumerateAndCreate("", e, "") } else { return args, fmt.Errorf("interface must have pointer underlying type. %v is passed", v.Type()) } default: return args, fmt.Errorf("object must be a pointer to struct or interface. %v is passed", v.Type()) } err := fm.fs.Parse(args) return fm.fs.Args(), err } func (fm *FlagMaker) enumerateAndCreate(prefix string, value reflect.Value, usage string) { switch value.Kind() { case // do no create flag for these types reflect.Map, reflect.Uintptr, reflect.UnsafePointer, reflect.Array, reflect.Chan, reflect.Func: return case reflect.Slice: // only support slice of strings, ints and float64s switch value.Type().Elem().Kind() { case reflect.String: fm.defineStringSlice(prefix, value, usage) case reflect.Int: fm.defineIntSlice(prefix, value, usage) case reflect.Float64: fm.defineFloat64Slice(prefix, value, usage) } return case // Basic value types reflect.String, reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: fm.defineFlag(prefix, value, usage) return case reflect.Interface: if !value.IsNil() { fm.enumerateAndCreate(prefix, value.Elem(), usage) } return case reflect.Ptr: if value.IsNil() { value.Set(reflect.New(value.Type().Elem())) } fm.enumerateAndCreate(prefix, value.Elem(), usage) return case reflect.Struct: // keep going default: panic(fmt.Sprintf("unknown reflected kind %v", value.Kind())) } numFields := value.NumField() tt := value.Type() for i := 0; i < numFields; i++ { stField := tt.Field(i) // Skip unexported fields, as only exported fields can be set. This is similar to how json and yaml work. if stField.PkgPath != "" && !stField.Anonymous { continue } if stField.Anonymous && fm.getUnderlyingType(stField.Type).Kind() != reflect.Struct { continue } field := value.Field(i) optName := fm.getName(stField) if len(prefix) > 0 && !fm.opts.Flatten { optName = prefix + "." + optName } usageDesc := fm.getUsage(optName, stField) if len(usageDesc) == 0 { optName = optName } fm.enumerateAndCreate(optName, field, usageDesc) } } func (fm *FlagMaker) getName(field reflect.StructField) string { name := field.Tag.Get(fm.opts.TagName) if len(name) == 0 { if field.Anonymous { name = fm.getUnderlyingType(field.Type).Name() } else { name = field.Name } } if fm.opts.UseLowerCase { return strings.ToLower(name) } return name } func (fm *FlagMaker) getUsage(name string, field reflect.StructField) string { usage := field.Tag.Get(fm.opts.TagUsage) if len(usage) == 0 { usage = name } return usage } func (fm *FlagMaker) getUnderlyingType(ttype reflect.Type) reflect.Type { // this only deals with *T unnamed type, other unnamed types, e.g. []int, struct{} // will return empty string. if ttype.Kind() == reflect.Ptr { return fm.getUnderlyingType(ttype.Elem()) } return ttype } // Each object has its type (which prescribes the possible operations/methods // could be invoked; it also has an underlying 'kind', int, float, struct etc. // Since user can freely define types, one 'kind' of object may correpond to // many types. We cannot do type assertion because types of same kind are still // different types. Instead, we convert to the primitive types that corresponds // to the kinds and create flag vars. One thing to know is that, the whole point // of defineFlag() method is to define flag.Vars that points to certain field // of the struct so that command line values can modify the struct. We cannot // define a flag var pointing to arbitrary 'free' varible. // I wish GoLang had macro... var ( stringPtrType = reflect.TypeOf((*string)(nil)) boolPtrType = reflect.TypeOf((*bool)(nil)) float32PtrType = reflect.TypeOf((*float32)(nil)) float64PtrType = reflect.TypeOf((*float64)(nil)) intPtrType = reflect.TypeOf((*int)(nil)) int8PtrType = reflect.TypeOf((*int8)(nil)) int16PtrType = reflect.TypeOf((*int16)(nil)) int32PtrType = reflect.TypeOf((*int32)(nil)) int64PtrType = reflect.TypeOf((*int64)(nil)) uintPtrType = reflect.TypeOf((*uint)(nil)) uint8PtrType = reflect.TypeOf((*uint8)(nil)) uint16PtrType = reflect.TypeOf((*uint16)(nil)) uint32PtrType = reflect.TypeOf((*uint32)(nil)) uint64PtrType = reflect.TypeOf((*uint64)(nil)) ) func (fm *FlagMaker) defineFlag(name string, value reflect.Value, usage string) { // v must be scalar, otherwise panic ptrValue := value.Addr() switch value.Kind() { case reflect.String: v := ptrValue.Convert(stringPtrType).Interface().(*string) fm.fs.StringVar(v, name, value.String(), usage) case reflect.Bool: v := ptrValue.Convert(boolPtrType).Interface().(*bool) fm.fs.BoolVar(v, name, value.Bool(), usage) case reflect.Int: v := ptrValue.Convert(intPtrType).Interface().(*int) fm.fs.IntVar(v, name, int(value.Int()), usage) case reflect.Int8: v := ptrValue.Convert(int8PtrType).Interface().(*int8) fm.fs.Var(newInt8Value(v), name, usage) case reflect.Int16: v := ptrValue.Convert(int16PtrType).Interface().(*int16) fm.fs.Var(newInt16Value(v), name, usage) case reflect.Int32: v := ptrValue.Convert(int32PtrType).Interface().(*int32) fm.fs.Var(newInt32Value(v), name, usage) case reflect.Int64: switch v := ptrValue.Interface().(type) { case *int64: fm.fs.Int64Var(v, name, value.Int(), usage) case *time.Duration: fm.fs.DurationVar(v, name, value.Interface().(time.Duration), usage) default: // (TODO) if one type defines time.Duration, we'll create a int64 flag for it. // Find some acceptible way to deal with it. vv := ptrValue.Convert(int64PtrType).Interface().(*int64) fm.fs.Int64Var(vv, name, value.Int(), usage) } case reflect.Float32: v := ptrValue.Convert(float32PtrType).Interface().(*float32) fm.fs.Var(newFloat32Value(v), name, usage) case reflect.Float64: v := ptrValue.Convert(float64PtrType).Interface().(*float64) fm.fs.Float64Var(v, name, value.Float(), usage) case reflect.Uint: v := ptrValue.Convert(uintPtrType).Interface().(*uint) fm.fs.UintVar(v, name, uint(value.Uint()), usage) case reflect.Uint8: v := ptrValue.Convert(uint8PtrType).Interface().(*uint8) fm.fs.Var(newUint8Value(v), name, usage) case reflect.Uint16: v := ptrValue.Convert(uint16PtrType).Interface().(*uint16) fm.fs.Var(newUint16Value(v), name, usage) case reflect.Uint32: v := ptrValue.Convert(uint32PtrType).Interface().(*uint32) fm.fs.Var(newUint32Value(v), name, usage) case reflect.Uint64: v := ptrValue.Convert(uint64PtrType).Interface().(*uint64) fm.fs.Uint64Var(v, name, value.Uint(), usage) } } func (fm *FlagMaker) defineStringSlice(name string, value reflect.Value, usage string) { ptrValue := value.Addr().Interface().(*[]string) fm.fs.Var(newStringSlice(ptrValue), name, usage) } func (fm *FlagMaker) defineIntSlice(name string, value reflect.Value, usage string) { ptrValue := value.Addr().Interface().(*[]int) fm.fs.Var(newIntSlice(ptrValue), name, usage) } func (fm *FlagMaker) defineFloat64Slice(name string, value reflect.Value, usage string) { ptrValue := value.Addr().Interface().(*[]float64) fm.fs.Var(newFloat64Slice(ptrValue), name, usage) } flags/vars.go 0 → 100644 +246 −0 Original line number Diff line number Diff line // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package flags import ( "fmt" "strconv" ) // additional types type int8Value int8 type int16Value int16 type int32Value int32 type f32Value float32 type uint8Value uint8 type uint32Value uint32 type uint16Value uint16 // Var handlers for each of the types func newInt8Value(p *int8) *int8Value { return (*int8Value)(p) } func newInt16Value(p *int16) *int16Value { return (*int16Value)(p) } func newInt32Value(p *int32) *int32Value { return (*int32Value)(p) } func newFloat32Value(p *float32) *f32Value { return (*f32Value)(p) } func newUint8Value(p *uint8) *uint8Value { return (*uint8Value)(p) } func newUint16Value(p *uint16) *uint16Value { return (*uint16Value)(p) } func newUint32Value(p *uint32) *uint32Value { return (*uint32Value)(p) } // Setters for each of the types func (f *int8Value) Set(s string) error { v, err := strconv.ParseInt(s, 10, 8) if err != nil { return err } *f = int8Value(v) return nil } func (f *int16Value) Set(s string) error { v, err := strconv.ParseInt(s, 10, 16) if err != nil { return err } *f = int16Value(v) return nil } func (f *int32Value) Set(s string) error { v, err := strconv.ParseInt(s, 10, 32) if err != nil { return err } *f = int32Value(v) return nil } func (f *f32Value) Set(s string) error { v, err := strconv.ParseFloat(s, 32) if err != nil { return err } *f = f32Value(v) return nil } func (f *uint8Value) Set(s string) error { v, err := strconv.ParseUint(s, 10, 8) if err != nil { return err } *f = uint8Value(v) return nil } func (f *uint16Value) Set(s string) error { v, err := strconv.ParseUint(s, 10, 16) if err != nil { return err } *f = uint16Value(v) return nil } func (f *uint32Value) Set(s string) error { v, err := strconv.ParseUint(s, 10, 32) if err != nil { return err } *f = uint32Value(v) return nil } // Getters for each of the types func (f *int8Value) Get() interface{} { return int8(*f) } func (f *int16Value) Get() interface{} { return int16(*f) } func (f *int32Value) Get() interface{} { return int32(*f) } func (f *f32Value) Get() interface{} { return float32(*f) } func (f *uint8Value) Get() interface{} { return uint8(*f) } func (f *uint16Value) Get() interface{} { return uint16(*f) } func (f *uint32Value) Get() interface{} { return uint32(*f) } // Stringers for each of the types func (f *int8Value) String() string { return fmt.Sprintf("%v", *f) } func (f *int16Value) String() string { return fmt.Sprintf("%v", *f) } func (f *int32Value) String() string { return fmt.Sprintf("%v", *f) } func (f *f32Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint8Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint16Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint32Value) String() string { return fmt.Sprintf("%v", *f) } // string slice // string slice type strSlice struct { s *[]string set bool // if there a flag defined via command line, the slice will be cleared first. } func newStringSlice(p *[]string) *strSlice { return &strSlice{ s: p, set: false, } } func (s *strSlice) Set(str string) error { if !s.set { *s.s = (*s.s)[:0] s.set = true } *s.s = append(*s.s, str) return nil } func (s *strSlice) Get() interface{} { return []string(*s.s) } func (s *strSlice) String() string { return fmt.Sprintf("%v", s.s) } // int slice type intSlice struct { s *[]int set bool } func newIntSlice(p *[]int) *intSlice { return &intSlice{ s: p, set: false, } } func (is *intSlice) Set(str string) error { i, err := strconv.Atoi(str) if err != nil { return err } if !is.set { *is.s = (*is.s)[:0] is.set = true } *is.s = append(*is.s, i) return nil } func (is *intSlice) Get() interface{} { return []int(*is.s) } func (is *intSlice) String() string { return fmt.Sprintf("%v", is.s) } // float64 slice type float64Slice struct { s *[]float64 set bool } func newFloat64Slice(p *[]float64) *float64Slice { return &float64Slice{ s: p, set: false, } } func (is *float64Slice) Set(str string) error { i, err := strconv.ParseFloat(str, 64) if err != nil { return err } if !is.set { *is.s = (*is.s)[:0] is.set = true } *is.s = append(*is.s, i) return nil } func (is *float64Slice) Get() interface{} { return []float64(*is.s) } func (is *float64Slice) String() string { return fmt.Sprintf("%v", is.s) } Loading
flags/flags.go 0 → 100644 +494 −0 Original line number Diff line number Diff line // Copyright (c) 2016 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // Package flags provides an interface for automatically creating command line // options from a struct. // // Typically, if one wants to load from a yaml, one has to define a proper // struct, then yaml.Unmarshal(), this is all good. However, there are // situations where we want to load most of the configs from the file but // overriding some configs. // // Let's say we use a yaml to config our Db connections and upon start of the // application we load from the yaml file to get the necessary parameters to // create the connection. Our base.yaml looks like this // // base.yaml // --- // mysql: // user: 'foo' // password: 'xxxxxx' // mysql_defaults_file: ./mysql_defaults.ini // mysql_socket_path: /var/run/mysqld/mysqld.sock // ... more config options ... // // we want to load all the configs from it but we want to provide some // flexibility for the program to connect via a different db user. We could // define a --user command flag then after loading the yaml file, we override // the user field with what we get from --user flag. // // If there are many overriding like this, manual define these flags is // tedious. This package provides an automatic way to define this override, // which is, given a struct, it'll create all the flags which are name using // the field names of the struct. If one of these flags are set via command // line, the struct will be modified in-place to reflect the value from command // line, therefore the values of the fields in the struct are overridden // // YAML is just used as an example here. In practice, one can use any struct // to define flags. // // Let's say we have our configration object as the following. // // type logging struct { // Interval int // Path string // } // // type socket struct { // ReadTimeout time.Duration // WriteTimeout time.Duration // } // // type tcp struct { // ReadTimeout time.Duration // socket // } // // type network struct { // ReadTimeout time.Duration // WriteTimeout time.Duration // tcp // } // // type Cfg struct { // logging // network // } // // The following code // // func main() { // c := &Cfg{} // flags.ParseArgs(c, os.Args[1:]) // } // // will create the following flags // // -logging.interval int // logging.interval // -logging.path string // logging.path // -network.readtimeout duration // network.readtimeout // -network.tcp.readtimeout duration // network.tcp.readtimeout // -network.tcp.socket.readtimeout duration // network.tcp.socket.readtimeout // -network.tcp.socket.writetimeout duration // network.tcp.socket.writetimeout // -network.writetimeout duration // network.writetimeout // // flags to subcommands are naturally suported. // // func main() { // cmd := os.Args[1] // switch cmd { // case "new" // c1 := &Cfg1{} // ParseArgs(c1, os.Args[2:]) // case "update": // c2 := &Cfg2{} // ParseArgs(c2, os.Args[2:]) // // ... more sub commands ... // } // } // // One can set Flatten to true when calling NewFlagMakerAdv, in which case, // flags are created without namespacing. For example, // // type auth struct { // Token string // Tag float64 // } // // type credentials struct { // User string // Password string // auth // } // // type database struct { // DBName string // TableName string // credentials // } // // type Cfg struct { // logging // database // } // // func main() { // c := &Cfg{} // flags.ParseArgs(c, os.Args[1:]) // } // // will create the following flags // -dbname string // dbname // -interval int // interval // -password string // password // -path string // path // -tablename string // tablename // -tag float // tag // -token string // token // -user string // user // // Please be aware that usual GoLang flag creation rules apply, i.e., if there are // duplication in flag names (in the flattened case it's more likely to happen // unless the caller make due dilligence to create the struct properly), it panics. // // // Note that not all types can have command line flags created for. map, channel // and function type will not defien a flag corresponding to the field. Pointer // types are properly handled and slice type will create multi-value command // line flags. That is, e.g. if a field foo's type is []int, one can use // --foo 10 --foo 15 --foo 20 to override this field value to be // []int{10, 15, 20}. For now, only []int, []string and []float64 are supported // in this fashion. package flags import ( "flag" "fmt" "reflect" "strings" "time" ) // FlagMakingOptions control the way FlagMaker's behavior when defining flags. type FlagMakingOptions struct { // Use lower case flag names rather than the field name/tag name directly. UseLowerCase bool // Create flags in namespaced fashion Flatten bool // If there is a struct tag named 'TagName', use its value as the flag name. // The purpose is that, for yaml/json parsing we often have something like // Foobar string `yaml:"host_name"`, in which case the flag will be named // 'host_name' rather than 'foobar'. TagName string // If there is a struct tag named 'TagUsage', use its value as the usage description. TagUsage string } // FlagMaker enumerate all the exported fields of a struct recursively // and create corresponding command line flags. For anonymous fields, // they are only enumerated if they are pointers to structs. // Usual GoLang flag rules apply, e.g. duplicated flag names leads to // panic. type FlagMaker struct { opts *FlagMakingOptions // We don't consume os.Args directly unless told to. fs *flag.FlagSet } // NewFlagMaker creates a default FlagMaker which creates namespaced flags func NewFlagMaker() *FlagMaker { return NewFlagMakerAdv(&FlagMakingOptions{ UseLowerCase: true, Flatten: false, TagName: "yaml", TagUsage: "usage"}) } // NewFlagMakerAdv gives full control to create flags. func NewFlagMakerAdv(options *FlagMakingOptions) *FlagMaker { return &FlagMaker{ opts: options, fs: flag.NewFlagSet("xFlags", flag.ContinueOnError), } } // NewFlagMakerFlagSet gives full control to create flags. func NewFlagMakerFlagSet(options *FlagMakingOptions, fs *flag.FlagSet) *FlagMaker { return &FlagMaker{ opts: options, fs: fs, } } // ParseArgs parses the string arguments which should not contain the program name. // // obj is the struct to populate. args are the command line arguments, // typically obtained from os.Args. func ParseArgs(obj interface{}, args []string) ([]string, error) { fm := NewFlagMaker() return fm.ParseArgs(obj, args) } // PrintDefaults prints the default value and type of defined flags. // It just calls the standard 'flag' package's PrintDefaults. func (fm *FlagMaker) PrintDefaults() { fm.fs.PrintDefaults() } // ParseArgs parses the arguments based on the FlagMaker's setting. func (fm *FlagMaker) ParseArgs(obj interface{}, args []string) ([]string, error) { v := reflect.ValueOf(obj) if v.Kind() != reflect.Ptr { return args, fmt.Errorf("top level object must be a pointer. %v is passed", v.Type()) } if v.IsNil() { return args, fmt.Errorf("top level object cannot be nil") } switch e := v.Elem(); e.Kind() { case reflect.Struct: fm.enumerateAndCreate("", e, "") case reflect.Interface: if e.Elem().Kind() == reflect.Ptr { fm.enumerateAndCreate("", e, "") } else { return args, fmt.Errorf("interface must have pointer underlying type. %v is passed", v.Type()) } default: return args, fmt.Errorf("object must be a pointer to struct or interface. %v is passed", v.Type()) } err := fm.fs.Parse(args) return fm.fs.Args(), err } func (fm *FlagMaker) enumerateAndCreate(prefix string, value reflect.Value, usage string) { switch value.Kind() { case // do no create flag for these types reflect.Map, reflect.Uintptr, reflect.UnsafePointer, reflect.Array, reflect.Chan, reflect.Func: return case reflect.Slice: // only support slice of strings, ints and float64s switch value.Type().Elem().Kind() { case reflect.String: fm.defineStringSlice(prefix, value, usage) case reflect.Int: fm.defineIntSlice(prefix, value, usage) case reflect.Float64: fm.defineFloat64Slice(prefix, value, usage) } return case // Basic value types reflect.String, reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: fm.defineFlag(prefix, value, usage) return case reflect.Interface: if !value.IsNil() { fm.enumerateAndCreate(prefix, value.Elem(), usage) } return case reflect.Ptr: if value.IsNil() { value.Set(reflect.New(value.Type().Elem())) } fm.enumerateAndCreate(prefix, value.Elem(), usage) return case reflect.Struct: // keep going default: panic(fmt.Sprintf("unknown reflected kind %v", value.Kind())) } numFields := value.NumField() tt := value.Type() for i := 0; i < numFields; i++ { stField := tt.Field(i) // Skip unexported fields, as only exported fields can be set. This is similar to how json and yaml work. if stField.PkgPath != "" && !stField.Anonymous { continue } if stField.Anonymous && fm.getUnderlyingType(stField.Type).Kind() != reflect.Struct { continue } field := value.Field(i) optName := fm.getName(stField) if len(prefix) > 0 && !fm.opts.Flatten { optName = prefix + "." + optName } usageDesc := fm.getUsage(optName, stField) if len(usageDesc) == 0 { optName = optName } fm.enumerateAndCreate(optName, field, usageDesc) } } func (fm *FlagMaker) getName(field reflect.StructField) string { name := field.Tag.Get(fm.opts.TagName) if len(name) == 0 { if field.Anonymous { name = fm.getUnderlyingType(field.Type).Name() } else { name = field.Name } } if fm.opts.UseLowerCase { return strings.ToLower(name) } return name } func (fm *FlagMaker) getUsage(name string, field reflect.StructField) string { usage := field.Tag.Get(fm.opts.TagUsage) if len(usage) == 0 { usage = name } return usage } func (fm *FlagMaker) getUnderlyingType(ttype reflect.Type) reflect.Type { // this only deals with *T unnamed type, other unnamed types, e.g. []int, struct{} // will return empty string. if ttype.Kind() == reflect.Ptr { return fm.getUnderlyingType(ttype.Elem()) } return ttype } // Each object has its type (which prescribes the possible operations/methods // could be invoked; it also has an underlying 'kind', int, float, struct etc. // Since user can freely define types, one 'kind' of object may correpond to // many types. We cannot do type assertion because types of same kind are still // different types. Instead, we convert to the primitive types that corresponds // to the kinds and create flag vars. One thing to know is that, the whole point // of defineFlag() method is to define flag.Vars that points to certain field // of the struct so that command line values can modify the struct. We cannot // define a flag var pointing to arbitrary 'free' varible. // I wish GoLang had macro... var ( stringPtrType = reflect.TypeOf((*string)(nil)) boolPtrType = reflect.TypeOf((*bool)(nil)) float32PtrType = reflect.TypeOf((*float32)(nil)) float64PtrType = reflect.TypeOf((*float64)(nil)) intPtrType = reflect.TypeOf((*int)(nil)) int8PtrType = reflect.TypeOf((*int8)(nil)) int16PtrType = reflect.TypeOf((*int16)(nil)) int32PtrType = reflect.TypeOf((*int32)(nil)) int64PtrType = reflect.TypeOf((*int64)(nil)) uintPtrType = reflect.TypeOf((*uint)(nil)) uint8PtrType = reflect.TypeOf((*uint8)(nil)) uint16PtrType = reflect.TypeOf((*uint16)(nil)) uint32PtrType = reflect.TypeOf((*uint32)(nil)) uint64PtrType = reflect.TypeOf((*uint64)(nil)) ) func (fm *FlagMaker) defineFlag(name string, value reflect.Value, usage string) { // v must be scalar, otherwise panic ptrValue := value.Addr() switch value.Kind() { case reflect.String: v := ptrValue.Convert(stringPtrType).Interface().(*string) fm.fs.StringVar(v, name, value.String(), usage) case reflect.Bool: v := ptrValue.Convert(boolPtrType).Interface().(*bool) fm.fs.BoolVar(v, name, value.Bool(), usage) case reflect.Int: v := ptrValue.Convert(intPtrType).Interface().(*int) fm.fs.IntVar(v, name, int(value.Int()), usage) case reflect.Int8: v := ptrValue.Convert(int8PtrType).Interface().(*int8) fm.fs.Var(newInt8Value(v), name, usage) case reflect.Int16: v := ptrValue.Convert(int16PtrType).Interface().(*int16) fm.fs.Var(newInt16Value(v), name, usage) case reflect.Int32: v := ptrValue.Convert(int32PtrType).Interface().(*int32) fm.fs.Var(newInt32Value(v), name, usage) case reflect.Int64: switch v := ptrValue.Interface().(type) { case *int64: fm.fs.Int64Var(v, name, value.Int(), usage) case *time.Duration: fm.fs.DurationVar(v, name, value.Interface().(time.Duration), usage) default: // (TODO) if one type defines time.Duration, we'll create a int64 flag for it. // Find some acceptible way to deal with it. vv := ptrValue.Convert(int64PtrType).Interface().(*int64) fm.fs.Int64Var(vv, name, value.Int(), usage) } case reflect.Float32: v := ptrValue.Convert(float32PtrType).Interface().(*float32) fm.fs.Var(newFloat32Value(v), name, usage) case reflect.Float64: v := ptrValue.Convert(float64PtrType).Interface().(*float64) fm.fs.Float64Var(v, name, value.Float(), usage) case reflect.Uint: v := ptrValue.Convert(uintPtrType).Interface().(*uint) fm.fs.UintVar(v, name, uint(value.Uint()), usage) case reflect.Uint8: v := ptrValue.Convert(uint8PtrType).Interface().(*uint8) fm.fs.Var(newUint8Value(v), name, usage) case reflect.Uint16: v := ptrValue.Convert(uint16PtrType).Interface().(*uint16) fm.fs.Var(newUint16Value(v), name, usage) case reflect.Uint32: v := ptrValue.Convert(uint32PtrType).Interface().(*uint32) fm.fs.Var(newUint32Value(v), name, usage) case reflect.Uint64: v := ptrValue.Convert(uint64PtrType).Interface().(*uint64) fm.fs.Uint64Var(v, name, value.Uint(), usage) } } func (fm *FlagMaker) defineStringSlice(name string, value reflect.Value, usage string) { ptrValue := value.Addr().Interface().(*[]string) fm.fs.Var(newStringSlice(ptrValue), name, usage) } func (fm *FlagMaker) defineIntSlice(name string, value reflect.Value, usage string) { ptrValue := value.Addr().Interface().(*[]int) fm.fs.Var(newIntSlice(ptrValue), name, usage) } func (fm *FlagMaker) defineFloat64Slice(name string, value reflect.Value, usage string) { ptrValue := value.Addr().Interface().(*[]float64) fm.fs.Var(newFloat64Slice(ptrValue), name, usage) }
flags/vars.go 0 → 100644 +246 −0 Original line number Diff line number Diff line // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package flags import ( "fmt" "strconv" ) // additional types type int8Value int8 type int16Value int16 type int32Value int32 type f32Value float32 type uint8Value uint8 type uint32Value uint32 type uint16Value uint16 // Var handlers for each of the types func newInt8Value(p *int8) *int8Value { return (*int8Value)(p) } func newInt16Value(p *int16) *int16Value { return (*int16Value)(p) } func newInt32Value(p *int32) *int32Value { return (*int32Value)(p) } func newFloat32Value(p *float32) *f32Value { return (*f32Value)(p) } func newUint8Value(p *uint8) *uint8Value { return (*uint8Value)(p) } func newUint16Value(p *uint16) *uint16Value { return (*uint16Value)(p) } func newUint32Value(p *uint32) *uint32Value { return (*uint32Value)(p) } // Setters for each of the types func (f *int8Value) Set(s string) error { v, err := strconv.ParseInt(s, 10, 8) if err != nil { return err } *f = int8Value(v) return nil } func (f *int16Value) Set(s string) error { v, err := strconv.ParseInt(s, 10, 16) if err != nil { return err } *f = int16Value(v) return nil } func (f *int32Value) Set(s string) error { v, err := strconv.ParseInt(s, 10, 32) if err != nil { return err } *f = int32Value(v) return nil } func (f *f32Value) Set(s string) error { v, err := strconv.ParseFloat(s, 32) if err != nil { return err } *f = f32Value(v) return nil } func (f *uint8Value) Set(s string) error { v, err := strconv.ParseUint(s, 10, 8) if err != nil { return err } *f = uint8Value(v) return nil } func (f *uint16Value) Set(s string) error { v, err := strconv.ParseUint(s, 10, 16) if err != nil { return err } *f = uint16Value(v) return nil } func (f *uint32Value) Set(s string) error { v, err := strconv.ParseUint(s, 10, 32) if err != nil { return err } *f = uint32Value(v) return nil } // Getters for each of the types func (f *int8Value) Get() interface{} { return int8(*f) } func (f *int16Value) Get() interface{} { return int16(*f) } func (f *int32Value) Get() interface{} { return int32(*f) } func (f *f32Value) Get() interface{} { return float32(*f) } func (f *uint8Value) Get() interface{} { return uint8(*f) } func (f *uint16Value) Get() interface{} { return uint16(*f) } func (f *uint32Value) Get() interface{} { return uint32(*f) } // Stringers for each of the types func (f *int8Value) String() string { return fmt.Sprintf("%v", *f) } func (f *int16Value) String() string { return fmt.Sprintf("%v", *f) } func (f *int32Value) String() string { return fmt.Sprintf("%v", *f) } func (f *f32Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint8Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint16Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint32Value) String() string { return fmt.Sprintf("%v", *f) } // string slice // string slice type strSlice struct { s *[]string set bool // if there a flag defined via command line, the slice will be cleared first. } func newStringSlice(p *[]string) *strSlice { return &strSlice{ s: p, set: false, } } func (s *strSlice) Set(str string) error { if !s.set { *s.s = (*s.s)[:0] s.set = true } *s.s = append(*s.s, str) return nil } func (s *strSlice) Get() interface{} { return []string(*s.s) } func (s *strSlice) String() string { return fmt.Sprintf("%v", s.s) } // int slice type intSlice struct { s *[]int set bool } func newIntSlice(p *[]int) *intSlice { return &intSlice{ s: p, set: false, } } func (is *intSlice) Set(str string) error { i, err := strconv.Atoi(str) if err != nil { return err } if !is.set { *is.s = (*is.s)[:0] is.set = true } *is.s = append(*is.s, i) return nil } func (is *intSlice) Get() interface{} { return []int(*is.s) } func (is *intSlice) String() string { return fmt.Sprintf("%v", is.s) } // float64 slice type float64Slice struct { s *[]float64 set bool } func newFloat64Slice(p *[]float64) *float64Slice { return &float64Slice{ s: p, set: false, } } func (is *float64Slice) Set(str string) error { i, err := strconv.ParseFloat(str, 64) if err != nil { return err } if !is.set { *is.s = (*is.s)[:0] is.set = true } *is.s = append(*is.s, i) return nil } func (is *float64Slice) Get() interface{} { return []float64(*is.s) } func (is *float64Slice) String() string { return fmt.Sprintf("%v", is.s) }