diff --git a/cmd/root/root.go b/cmd/root/root.go index 2cfe136c..0e7e2c3e 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -19,6 +19,7 @@ import ( "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/download" "github.com/replicate/pget/pkg/logging" + "github.com/replicate/pget/pkg/overrides" ) const rootLongDesc = ` @@ -44,6 +45,7 @@ efficient file extractor, providing a streamlined solution for fetching and unpa var concurrency int var pidFile *cli.PIDFile var chunkSize string +var overridesFile string const chunkSizeDefault = "125M" @@ -170,6 +172,7 @@ func persistentFlags(cmd *cobra.Command) error { cmd.PersistentFlags().Duration(config.OptConnTimeout, 5*time.Second, "Timeout for establishing a connection, format is , e.g. 10s") cmd.PersistentFlags().StringVarP(&chunkSize, config.OptChunkSize, "m", chunkSizeDefault, "Chunk size (in bytes) to use when downloading a file (e.g. 10M)") cmd.PersistentFlags().StringVar(&chunkSize, config.OptMinimumChunkSize, chunkSizeDefault, "Minimum chunk size (in bytes) to use when downloading a file (e.g. 10M)") + cmd.PersistentFlags().StringVar(&overridesFile, config.OptOverridesFile, "", "Override file for routing") cmd.PersistentFlags().BoolP(config.OptForce, "f", false, "OptForce download, overwriting existing file") cmd.PersistentFlags().StringSlice(config.OptResolve, []string{}, "OptResolve hostnames to specific IPs") cmd.PersistentFlags().IntP(config.OptRetries, "r", 5, "Number of retries when attempting to retrieve a file") @@ -273,6 +276,17 @@ func rootExecute(ctx context.Context, urlString, dest string) error { Consumer: consumer, } + if overridesFile != "" { + table, err := overrides.ParseRoutingTable(overridesFile) + if err != nil { + return err + } + options := pget.Options{ + RoutingTable: table, + } + getter.Options = options + } + // TODO DRY this if srvName := config.GetCacheSRV(); srvName != "" { downloadOpts.SliceSize = 500 * humanize.MiByte diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index 303c8ad5..5cbb044a 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -22,6 +22,7 @@ const ( OptMaxConcurrentFiles = "max-concurrent-files" OptMinimumChunkSize = "minimum-chunk-size" OptOutputConsumer = "output" + OptOverridesFile = "overrides-file" OptPIDFile = "pid-file" OptResolve = "resolve" OptRetries = "retries" diff --git a/pkg/overrides/routing_table.go b/pkg/overrides/routing_table.go new file mode 100644 index 00000000..add41e08 --- /dev/null +++ b/pkg/overrides/routing_table.go @@ -0,0 +1,31 @@ +package overrides + +import ( + "encoding/json" + "os" +) + +type RoutingTable map[string]string + +type RoutingTableRecord struct { + Key string `json:"key"` + Value string `json:"value"` +} + +func ParseRoutingTable(filepath string) (RoutingTable, error) { + r := make([]RoutingTableRecord, 0) + table := make(RoutingTable) + f, err := os.Open(filepath) + if err != nil { + return table, err + } + defer f.Close() + err = json.NewDecoder(f).Decode(&r) + if err != nil { + return table, err + } + for _, record := range r { + table[record.Key] = record.Value + } + return table, nil +} diff --git a/pkg/pget.go b/pkg/pget.go index 96e7a2bd..f90f9e66 100644 --- a/pkg/pget.go +++ b/pkg/pget.go @@ -12,6 +12,7 @@ import ( "github.com/replicate/pget/pkg/consumer" "github.com/replicate/pget/pkg/download" "github.com/replicate/pget/pkg/logging" + "github.com/replicate/pget/pkg/overrides" ) type Getter struct { @@ -22,6 +23,7 @@ type Getter struct { type Options struct { MaxConcurrentFiles int + RoutingTable overrides.RoutingTable } type ManifestEntry struct { @@ -42,6 +44,9 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int } logger := logging.GetLogger() downloadStartTime := time.Now() + if override, ok := g.Options.RoutingTable[url]; ok { + url = override + } buffer, fileSize, err := g.Downloader.Fetch(ctx, url) if err != nil { return fileSize, 0, err