Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ func (c *Client) normalizeModelName(model string) string {
return model
}

// Normalize HuggingFace short URL (hf.co) to canonical form (huggingface.co)
// This ensures that hf.co/org/model and huggingface.co/org/model are treated as the same model
if rest, found := strings.CutPrefix(model, "hf.co/"); found {
model = "huggingface.co/" + rest
}

// If it looks like an ID or digest, try to resolve it to full ID
if c.looksLikeID(model) || c.looksLikeDigest(model) {
if fullID := c.resolveID(model); fullID != "" {
Expand Down Expand Up @@ -243,6 +249,24 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
// HuggingFace references always use native pull (download raw files from HF Hub)
if isHuggingFaceReference(originalReference) {
c.log.Infoln("Using native HuggingFace pull for:", utils.SanitizeForLog(reference))

// Check if model already exists in local store (reference is already normalized)
localModel, err := c.store.Read(reference)
if err == nil {
c.log.Infoln("HuggingFace model found in local store:", utils.SanitizeForLog(reference))
cfg, err := localModel.Config()
if err != nil {
return fmt.Errorf("getting cached model config: %w", err)
}
if err := progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize()), oci.ModePull); err != nil {
c.log.Warnf("Writing progress: %v", err)
}
return nil
}
if !errors.Is(err, ErrModelNotFound) {
return fmt.Errorf("checking for cached HuggingFace model: %w", err)
}

// Pass original reference to preserve case-sensitivity for HuggingFace API
return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token)
}
Expand Down
53 changes: 53 additions & 0 deletions pkg/distribution/distribution/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,3 +1140,56 @@ func randomFile(size int64) (string, error) {

return f.Name(), nil
}

func TestPullHuggingFaceModelFromCache(t *testing.T) {
testCases := []struct {
name string
pullRef string
}{
{
name: "full URL",
pullRef: "huggingface.co/testorg/testmodel:latest",
},
{
name: "short URL",
pullRef: "hf.co/testorg/testmodel:latest",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()

// Create client
client, err := newTestClient(tempDir)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}

// Create a test model and write it to the store with a normalized HuggingFace tag
model, err := gguf.NewModel(testGGUFFile)
if err != nil {
t.Fatalf("Failed to create model: %v", err)
}

// Store with normalized tag (huggingface.co)
hfTag := "huggingface.co/testorg/testmodel:latest"
if err := client.store.Write(model, []string{hfTag}, nil); err != nil {
t.Fatalf("Failed to write model to store: %v", err)
}

// Now try to pull using the test case's reference - it should use the cache
var progressBuffer bytes.Buffer
err = client.PullModel(t.Context(), tc.pullRef, &progressBuffer)
if err != nil {
t.Fatalf("Failed to pull model from cache: %v", err)
}

// Verify that progress shows it was cached
progressOutput := progressBuffer.String()
if !strings.Contains(progressOutput, "Using cached model") {
t.Errorf("Expected progress to indicate cached model, got: %s", progressOutput)
}
})
}
}