diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index f150adef..b5ba001d 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -123,6 +123,12 @@ func (c *Client) normalizeModelName(model string) string { return model } + // Normalize HuggingFace short URL (hf.co) to canonical form (huggingface.co) + // This ensures that hf.co/org/model and huggingface.co/org/model are treated as the same model + if rest, found := strings.CutPrefix(model, "hf.co/"); found { + model = "huggingface.co/" + rest + } + // If it looks like an ID or digest, try to resolve it to full ID if c.looksLikeID(model) || c.looksLikeDigest(model) { if fullID := c.resolveID(model); fullID != "" { @@ -243,6 +249,24 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // HuggingFace references always use native pull (download raw files from HF Hub) if isHuggingFaceReference(originalReference) { c.log.Infoln("Using native HuggingFace pull for:", utils.SanitizeForLog(reference)) + + // Check if model already exists in local store (reference is already normalized) + localModel, err := c.store.Read(reference) + if err == nil { + c.log.Infoln("HuggingFace model found in local store:", utils.SanitizeForLog(reference)) + cfg, err := localModel.Config() + if err != nil { + return fmt.Errorf("getting cached model config: %w", err) + } + if err := progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize()), oci.ModePull); err != nil { + c.log.Warnf("Writing progress: %v", err) + } + return nil + } + if !errors.Is(err, ErrModelNotFound) { + return fmt.Errorf("checking for cached HuggingFace model: %w", err) + } + // Pass original reference to preserve case-sensitivity for HuggingFace API return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token) } diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go index 7e3b65f1..3ed11b5a 100644 --- a/pkg/distribution/distribution/client_test.go +++ b/pkg/distribution/distribution/client_test.go @@ -1140,3 +1140,56 @@ func randomFile(size int64) (string, error) { return f.Name(), nil } + +func TestPullHuggingFaceModelFromCache(t *testing.T) { + testCases := []struct { + name string + pullRef string + }{ + { + name: "full URL", + pullRef: "huggingface.co/testorg/testmodel:latest", + }, + { + name: "short URL", + pullRef: "hf.co/testorg/testmodel:latest", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tempDir := t.TempDir() + + // Create client + client, err := newTestClient(tempDir) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Create a test model and write it to the store with a normalized HuggingFace tag + model, err := gguf.NewModel(testGGUFFile) + if err != nil { + t.Fatalf("Failed to create model: %v", err) + } + + // Store with normalized tag (huggingface.co) + hfTag := "huggingface.co/testorg/testmodel:latest" + if err := client.store.Write(model, []string{hfTag}, nil); err != nil { + t.Fatalf("Failed to write model to store: %v", err) + } + + // Now try to pull using the test case's reference - it should use the cache + var progressBuffer bytes.Buffer + err = client.PullModel(t.Context(), tc.pullRef, &progressBuffer) + if err != nil { + t.Fatalf("Failed to pull model from cache: %v", err) + } + + // Verify that progress shows it was cached + progressOutput := progressBuffer.String() + if !strings.Contains(progressOutput, "Using cached model") { + t.Errorf("Expected progress to indicate cached model, got: %s", progressOutput) + } + }) + } +}