Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions pkg/cmd/kitimport/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/kitops-ml/kitops/pkg/lib/constants"
"github.com/kitops-ml/kitops/pkg/lib/git"
"github.com/kitops-ml/kitops/pkg/lib/hf"
repoutils "github.com/kitops-ml/kitops/pkg/lib/repo/util"
"github.com/kitops-ml/kitops/pkg/output"

Expand Down Expand Up @@ -130,13 +131,19 @@ func (opts *importOptions) complete(ctx context.Context, args []string) error {
opts.repo = args[0]

if opts.tag == "" {
tag, err := extractRepoFromURL(opts.repo)
if err != nil {
output.Errorf("Could not generate tag from URL: %s", err)
return fmt.Errorf("use flag --tag to set a tag for ModelKit")
var tagRepo string
if repo, _, err := hf.ParseHuggingFaceRepo(opts.repo); err == nil {
tagRepo = repo
} else {
repo, err := extractRepoFromURL(opts.repo)
if err != nil {
output.Errorf("Could not generate tag from URL: %s", err)
return fmt.Errorf("use flag --tag to set a tag for ModelKit")
}
tagRepo = repo
}
tag = strings.ToLower(tag)
opts.tag = fmt.Sprintf("%s:latest", tag)
tagRepo = strings.ToLower(tagRepo)
opts.tag = fmt.Sprintf("%s:latest", tagRepo)
output.Infof("Using tag %s. Use flag --tag to override", opts.tag)
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/cmd/kitimport/hfimport.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (

func importUsingHF(ctx context.Context, opts *importOptions) error {
// Handle full HF URLs by extracting repository name from URL
repo, err := extractRepoFromURL(opts.repo)
repo, repoType, err := hf.ParseHuggingFaceRepo(opts.repo)
if err != nil {
return fmt.Errorf("could not process URL %s: %w", opts.repo, err)
}
Expand All @@ -53,7 +53,7 @@ func importUsingHF(ctx context.Context, opts *importOptions) error {
}
}()

dirListing, err := hf.ListFiles(ctx, repo, opts.repoRef, opts.token)
dirListing, err := hf.ListFiles(ctx, repo, opts.repoRef, opts.token, repoType)
if err != nil {
return fmt.Errorf("failed to list files from HuggingFace API: %w", err)
}
Expand Down Expand Up @@ -106,7 +106,7 @@ func importUsingHF(ctx context.Context, opts *importOptions) error {
if err != nil {
return err
}
if err := hf.DownloadFiles(ctx, repo, opts.repoRef, tmpDir, toDownload, opts.token, opts.concurrency); err != nil {
if err := hf.DownloadFiles(ctx, repo, opts.repoRef, tmpDir, toDownload, opts.token, opts.concurrency, repoType); err != nil {
return fmt.Errorf("error downloading repository: %w", err)
}

Expand Down
13 changes: 11 additions & 2 deletions pkg/lib/hf/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ import (
)

const (
resolveURLFmt = "https://huggingface.co/%s/resolve/%s/%s"
modelResolveURLFmt = "https://huggingface.co/%s/resolve/%s/%s"
datasetResolveURLFmt = "https://huggingface.co/datasets/%s/resolve/%s/%s"
)

func DownloadFiles(
ctx context.Context,
modelRepo, repoRef, destDir string,
files []kfgen.FileListing,
token string,
maxConcurrency int) error {
maxConcurrency int,
repoType RepositoryType) error {

client := &http.Client{
Timeout: 1 * time.Hour,
Expand All @@ -55,6 +57,13 @@ func DownloadFiles(

progress, plog := output.NewDownloadProgress()

var resolveURLFmt string
if repoType == RepoTypeDataset {
resolveURLFmt = datasetResolveURLFmt
} else {
resolveURLFmt = modelResolveURLFmt
}

for _, f := range files {
if err := sem.Acquire(errCtx, 1); err != nil {
semErr = err
Expand Down
15 changes: 12 additions & 3 deletions pkg/lib/hf/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import (
)

const (
treeURLFmt = "https://huggingface.co/api/models/%s/tree/%s"
modelTreeURLFmt = "https://huggingface.co/api/models/%s/tree/%s"
datasetTreeURLFmt = "https://huggingface.co/api/datasets/%s/tree/%s"
)

type hfTreeResponse []struct {
Expand All @@ -45,11 +46,19 @@ type hfErrorResponse struct {
Error string `json:"error"`
}

func ListFiles(ctx context.Context, modelRepo, ref string, token string) (*kfgen.DirectoryListing, error) {
func ListFiles(ctx context.Context, modelRepo, ref string, token string, repoType RepositoryType) (*kfgen.DirectoryListing, error) {
client := &http.Client{
Timeout: 10 * time.Second,
}
baseURL, err := url.Parse(fmt.Sprintf(treeURLFmt, modelRepo, ref))

var treeURL string
if repoType == RepoTypeDataset {
treeURL = fmt.Sprintf(datasetTreeURLFmt, modelRepo, ref)
} else {
treeURL = fmt.Sprintf(modelTreeURLFmt, modelRepo, ref)
}

baseURL, err := url.Parse(treeURL)
if err != nil {
return nil, fmt.Errorf("failed to parse URL: %w", err)
}
Expand Down
61 changes: 61 additions & 0 deletions pkg/lib/hf/repo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2025 The KitOps Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package hf

import (
"fmt"
"net/url"
"strings"
)

// RepositoryType represents the kind of Hugging Face repository.
type RepositoryType int

const (
	// RepoTypeUnknown is the zero value. ParseHuggingFaceRepo returns it
	// whenever it also returns a non-nil error.
	RepoTypeUnknown RepositoryType = iota
	// RepoTypeModel is returned for paths without a leading "datasets" segment.
	RepoTypeModel
	// RepoTypeDataset is returned for paths of the form "datasets/org/repo".
	RepoTypeDataset
)

// isHuggingFaceHost reports whether host is exactly huggingface.co or one of
// its subdomains. The comparison is case-insensitive. A substring match is
// deliberately avoided so that look-alike hosts such as
// "huggingface.co.evil.com" or "evilhuggingface.co" are rejected.
func isHuggingFaceHost(host string) bool {
	const hubHost = "huggingface.co"
	host = strings.ToLower(host)
	return host == hubHost || strings.HasSuffix(host, "."+hubHost)
}

// ParseHuggingFaceRepo parses a Hugging Face repository URL or path and returns
// the normalized repository path (org/repo) and the repository type.
//
// Accepted inputs include full URLs ("https://huggingface.co/org/repo"),
// scheme-less URLs ("huggingface.co/datasets/org/repo") and bare repository
// paths ("org/repo", "datasets/org/repo"). A leading "datasets" path segment
// selects RepoTypeDataset; everything else is treated as RepoTypeModel, using
// the last two path segments as the repository.
//
// An error (with an empty repository and RepoTypeUnknown) is returned when the
// input cannot be parsed as a URL, when it names a host other than
// huggingface.co or a subdomain of it, or when the path does not contain enough
// segments to form org/repo.
func ParseHuggingFaceRepo(rawURL string) (string, RepositoryType, error) {
	u, err := url.Parse(rawURL)
	if err != nil {
		return "", RepoTypeUnknown, fmt.Errorf("failed to parse url: %w", err)
	}

	// Hostname() strips any port, so "huggingface.co:443" is still accepted.
	if u.Host != "" && !isHuggingFaceHost(u.Hostname()) {
		return "", RepoTypeUnknown, fmt.Errorf("not a Hugging Face repository")
	}

	path := strings.Trim(u.Path, "/")
	segments := strings.Split(path, "/")

	// Without a scheme, url.Parse leaves Host empty and the host name shows up
	// as the first path segment. Drop it so that "huggingface.co/datasets/org/repo"
	// is classified as a dataset rather than a model. Requiring at least three
	// segments keeps plain two-segment "org/repo" inputs untouched.
	if u.Host == "" && len(segments) >= 3 && isHuggingFaceHost(segments[0]) {
		segments = segments[1:]
	}

	switch {
	case len(segments) >= 3 && segments[0] == "datasets":
		return strings.Join(segments[1:3], "/"), RepoTypeDataset, nil
	case len(segments) == 2 && segments[0] == "datasets":
		// "datasets/<name>" is missing either the org or the repo.
		return "", RepoTypeUnknown, fmt.Errorf("could not extract repository from path '%s'", path)
	case len(segments) >= 2:
		return strings.Join(segments[len(segments)-2:], "/"), RepoTypeModel, nil
	default:
		return "", RepoTypeUnknown, fmt.Errorf("could not extract repository from path '%s'", path)
	}
}
58 changes: 58 additions & 0 deletions pkg/lib/hf/repo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2025 The KitOps Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package hf

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

// TestParseHuggingFaceRepo is a table-driven test for ParseHuggingFaceRepo.
// Each row either names the repository and type expected on success, or a
// regular expression that the returned error message must match.
func TestParseHuggingFaceRepo(t *testing.T) {
	type parseCase struct {
		in       string
		wantRepo string
		wantType RepositoryType
		wantErr  string // regexp; empty means the call must succeed
	}

	cases := []parseCase{
		{
			in:       "https://huggingface.co/org/repo",
			wantRepo: "org/repo",
			wantType: RepoTypeModel,
		},
		{
			in:       "https://huggingface.co/datasets/org/repo",
			wantRepo: "org/repo",
			wantType: RepoTypeDataset,
		},
		{
			in:       "org/repo",
			wantRepo: "org/repo",
			wantType: RepoTypeModel,
		},
		{
			in:       "datasets/org/repo",
			wantRepo: "org/repo",
			wantType: RepoTypeDataset,
		},
		{
			in:      "datasets/only-one-segment",
			wantErr: "could not extract repository",
		},
		{
			in:      "https://example.com/org/repo",
			wantErr: "not a Hugging Face repository",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(fmt.Sprintf("handles %s", tc.in), func(t *testing.T) {
			repo, repoType, err := ParseHuggingFaceRepo(tc.in)

			// Success rows: the values are only compared once the call is
			// known to have succeeded.
			if tc.wantErr == "" {
				if assert.NoError(t, err) {
					assert.Equal(t, tc.wantRepo, repo)
					assert.Equal(t, tc.wantType, repoType)
				}
				return
			}

			// Failure rows: the message is only matched when an error exists.
			if assert.Error(t, err) {
				assert.Regexp(t, tc.wantErr, err.Error())
			}
		})
	}
}