diff --git a/pkg/cmd/kitimport/cmd.go b/pkg/cmd/kitimport/cmd.go index ea100e92..8c9d0d96 100644 --- a/pkg/cmd/kitimport/cmd.go +++ b/pkg/cmd/kitimport/cmd.go @@ -25,6 +25,7 @@ import ( "github.com/kitops-ml/kitops/pkg/lib/constants" "github.com/kitops-ml/kitops/pkg/lib/git" + "github.com/kitops-ml/kitops/pkg/lib/hf" repoutils "github.com/kitops-ml/kitops/pkg/lib/repo/util" "github.com/kitops-ml/kitops/pkg/output" @@ -130,13 +131,19 @@ func (opts *importOptions) complete(ctx context.Context, args []string) error { opts.repo = args[0] if opts.tag == "" { - tag, err := extractRepoFromURL(opts.repo) - if err != nil { - output.Errorf("Could not generate tag from URL: %s", err) - return fmt.Errorf("use flag --tag to set a tag for ModelKit") + var tagRepo string + if repo, _, err := hf.ParseHuggingFaceRepo(opts.repo); err == nil { + tagRepo = repo + } else { + repo, err := extractRepoFromURL(opts.repo) + if err != nil { + output.Errorf("Could not generate tag from URL: %s", err) + return fmt.Errorf("use flag --tag to set a tag for ModelKit") + } + tagRepo = repo } - tag = strings.ToLower(tag) - opts.tag = fmt.Sprintf("%s:latest", tag) + tagRepo = strings.ToLower(tagRepo) + opts.tag = fmt.Sprintf("%s:latest", tagRepo) output.Infof("Using tag %s. Use flag --tag to override", opts.tag) } diff --git a/pkg/cmd/kitimport/hfimport.go b/pkg/cmd/kitimport/hfimport.go index 4ab76d07..18c731c1 100644 --- a/pkg/cmd/kitimport/hfimport.go +++ b/pkg/cmd/kitimport/hfimport.go @@ -37,7 +37,7 @@ import ( func importUsingHF(ctx context.Context, opts *importOptions) error { // Handle full HF URLs by extracting repository name from URL - repo, err := extractRepoFromURL(opts.repo) + repo, repoType, err := hf.ParseHuggingFaceRepo(opts.repo) if err != nil { return fmt.Errorf("could not process URL %s: %w", opts.repo, err) } @@ -53,7 +53,7 @@ func importUsingHF(ctx context.Context, opts *importOptions) error { } }() - dirListing, err := hf.ListFiles(ctx, repo, opts.repoRef, opts.token) + dirListing, err := hf.ListFiles(ctx, repo, opts.repoRef, opts.token, repoType) if err != nil { return fmt.Errorf("failed to list files from HuggingFace API: %w", err) } @@ -106,7 +106,7 @@ func importUsingHF(ctx context.Context, opts *importOptions) error { if err != nil { return err } - if err := hf.DownloadFiles(ctx, repo, opts.repoRef, tmpDir, toDownload, opts.token, opts.concurrency); err != nil { + if err := hf.DownloadFiles(ctx, repo, opts.repoRef, tmpDir, toDownload, opts.token, opts.concurrency, repoType); err != nil { return fmt.Errorf("error downloading repository: %w", err) } diff --git a/pkg/lib/hf/download.go b/pkg/lib/hf/download.go index 430c48bb..cb45ce94 100644 --- a/pkg/lib/hf/download.go +++ b/pkg/lib/hf/download.go @@ -35,7 +35,8 @@ import ( ) const ( - resolveURLFmt = "https://huggingface.co/%s/resolve/%s/%s" + modelResolveURLFmt = "https://huggingface.co/%s/resolve/%s/%s" + datasetResolveURLFmt = "https://huggingface.co/datasets/%s/resolve/%s/%s" ) func DownloadFiles( @@ -43,7 +44,8 @@ func DownloadFiles( modelRepo, repoRef, destDir string, files []kfgen.FileListing, token string, - maxConcurrency int) error { + maxConcurrency int, + repoType RepositoryType) error { client := &http.Client{ Timeout: 1 * time.Hour, @@ -55,6 +57,13 @@ func DownloadFiles( progress, plog := output.NewDownloadProgress() + var resolveURLFmt string + if repoType == RepoTypeDataset { + resolveURLFmt = datasetResolveURLFmt + } else { + resolveURLFmt = modelResolveURLFmt + } + for _, f := range files { if err := sem.Acquire(errCtx, 1); err != nil { semErr = err diff --git a/pkg/lib/hf/list.go b/pkg/lib/hf/list.go index 05a360df..6db186fa 100644 --- a/pkg/lib/hf/list.go +++ b/pkg/lib/hf/list.go @@ -31,7 +31,8 @@ import ( ) const ( - treeURLFmt = "https://huggingface.co/api/models/%s/tree/%s" + modelTreeURLFmt = "https://huggingface.co/api/models/%s/tree/%s" + datasetTreeURLFmt = "https://huggingface.co/api/datasets/%s/tree/%s" ) type hfTreeResponse []struct { @@ -45,11 +46,19 @@ type hfErrorResponse struct { Error string `json:"error"` } -func ListFiles(ctx context.Context, modelRepo, ref string, token string) (*kfgen.DirectoryListing, error) { +func ListFiles(ctx context.Context, modelRepo, ref string, token string, repoType RepositoryType) (*kfgen.DirectoryListing, error) { client := &http.Client{ Timeout: 10 * time.Second, } - baseURL, err := url.Parse(fmt.Sprintf(treeURLFmt, modelRepo, ref)) + + var treeURL string + if repoType == RepoTypeDataset { + treeURL = fmt.Sprintf(datasetTreeURLFmt, modelRepo, ref) + } else { + treeURL = fmt.Sprintf(modelTreeURLFmt, modelRepo, ref) + } + + baseURL, err := url.Parse(treeURL) if err != nil { return nil, fmt.Errorf("failed to parse URL: %w", err) } diff --git a/pkg/lib/hf/repo.go b/pkg/lib/hf/repo.go new file mode 100644 index 00000000..7028d6e8 --- /dev/null +++ b/pkg/lib/hf/repo.go @@ -0,0 +1,61 @@ +// Copyright 2025 The KitOps Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package hf + +import ( + "fmt" + "net/url" + "strings" +) + +// RepositoryType represents the kind of Hugging Face repository. +type RepositoryType int + +const ( + RepoTypeUnknown RepositoryType = iota + RepoTypeModel + RepoTypeDataset +) + +// ParseHuggingFaceRepo parses a Hugging Face repository URL or path and returns +// the normalized repository path (org/repo) and the repository type. +func ParseHuggingFaceRepo(rawURL string) (string, RepositoryType, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", RepoTypeUnknown, fmt.Errorf("failed to parse url: %w", err) + } + + if u.Host != "" && !strings.Contains(u.Host, "huggingface.co") { + return "", RepoTypeUnknown, fmt.Errorf("not a Hugging Face repository") + } + + path := strings.Trim(u.Path, "/") + segments := strings.Split(path, "/") + + if len(segments) >= 3 && segments[0] == "datasets" { + return strings.Join(segments[1:3], "/"), RepoTypeDataset, nil + } + if len(segments) == 2 && segments[0] == "datasets" { + return "", RepoTypeUnknown, fmt.Errorf("could not extract repository from path '%s'", path) + } + + if len(segments) >= 2 { + return strings.Join(segments[len(segments)-2:], "/"), RepoTypeModel, nil + } + + return "", RepoTypeUnknown, fmt.Errorf("could not extract repository from path '%s'", path) +} diff --git a/pkg/lib/hf/repo_test.go b/pkg/lib/hf/repo_test.go new file mode 100644 index 00000000..ffaf807a --- /dev/null +++ b/pkg/lib/hf/repo_test.go @@ -0,0 +1,58 @@ +// Copyright 2025 The KitOps Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package hf + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseHuggingFaceRepo(t *testing.T) { + testcases := []struct { + input string + expectedRepo string + expectedType RepositoryType + expectErrRegexp string + }{ + {input: "https://huggingface.co/org/repo", expectedRepo: "org/repo", expectedType: RepoTypeModel}, + {input: "https://huggingface.co/datasets/org/repo", expectedRepo: "org/repo", expectedType: RepoTypeDataset}, + {input: "org/repo", expectedRepo: "org/repo", expectedType: RepoTypeModel}, + {input: "datasets/org/repo", expectedRepo: "org/repo", expectedType: RepoTypeDataset}, + {input: "datasets/only-one-segment", expectErrRegexp: "could not extract repository"}, + {input: "https://example.com/org/repo", expectErrRegexp: "not a Hugging Face repository"}, + } + + for _, tt := range testcases { + t.Run(fmt.Sprintf("handles %s", tt.input), func(t *testing.T) { + actualRepo, actualType, actualErr := ParseHuggingFaceRepo(tt.input) + if tt.expectErrRegexp != "" { + if !assert.Error(t, actualErr) { + return + } + assert.Regexp(t, tt.expectErrRegexp, actualErr.Error()) + } else { + if !assert.NoError(t, actualErr) { + return + } + assert.Equal(t, tt.expectedRepo, actualRepo) + assert.Equal(t, tt.expectedType, actualType) + } + }) + } +}