Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 260 additions & 0 deletions integration-tests/concurrent/concurrent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"path/filepath"
"strings"
"sync"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -219,6 +220,43 @@ func waitForServerReady(serverURL string, timeout time.Duration) bool {
return false
}

// waitForServerStatus polls /health-check until the server reports the given status.
// Unlike waitForServerReady which waits for READY, this can wait for intermediate
// states like STARTING (useful for testing signals during setup).
//
// It returns true once the reported status equals targetStatus, and false if the
// server reports a terminal failure state (SETUP_FAILED or DEFUNCT) or the
// timeout elapses first.
func waitForServerStatus(serverURL string, targetStatus string, timeout time.Duration) bool {
	const pollInterval = 200 * time.Millisecond

	httpClient := &http.Client{Timeout: 2 * time.Second}

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(pollInterval) {
		resp, err := httpClient.Get(serverURL + "/health-check")
		if err != nil {
			// Server may not be accepting connections yet; retry.
			continue
		}

		var health struct {
			Status string `json:"status"`
		}
		decodeErr := json.NewDecoder(resp.Body).Decode(&health)
		resp.Body.Close()
		if decodeErr != nil {
			continue
		}

		switch health.Status {
		case targetStatus:
			return true
		case "SETUP_FAILED", "DEFUNCT":
			// Terminal states: the target status can never be reached now.
			return false
		}
	}

	return false
}

// allocatePort finds an available TCP port by letting the OS assign one.
func allocatePort() (int, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
Expand Down Expand Up @@ -247,3 +285,225 @@ class Predictor(BasePredictor):
await asyncio.sleep(sleep)
return f"wake up {s}"
`

// TestConcurrentAboveLimit tests that sending more predictions than max_concurrency
// returns a 409 Conflict for the excess prediction.
func TestConcurrentAboveLimit(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping slow test in short mode")
	}

	tmpDir, err := os.MkdirTemp("", "cog-above-limit-test-*")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)

	if err := os.WriteFile(filepath.Join(tmpDir, "cog.yaml"), []byte(aboveLimitCogYAML), 0o644); err != nil {
		t.Fatalf("failed to write cog.yaml: %v", err)
	}
	if err := os.WriteFile(filepath.Join(tmpDir, "predict.py"), []byte(predictPy), 0o644); err != nil {
		t.Fatalf("failed to write predict.py: %v", err)
	}

	cogBinary, err := harness.ResolveCogBinary()
	if err != nil {
		t.Fatalf("failed to resolve cog binary: %v", err)
	}

	imageName := fmt.Sprintf("cog-above-limit-test-%d", time.Now().UnixNano())
	defer func() {
		// Best-effort image cleanup; errors ignored (image may not exist).
		exec.Command("docker", "rmi", "-f", imageName).Run()
	}()

	t.Log("Building image...")
	buildCmd := exec.Command(cogBinary, "build", "-t", imageName)
	buildCmd.Dir = tmpDir
	buildCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1")
	if output, err := buildCmd.CombinedOutput(); err != nil {
		t.Fatalf("failed to build image: %v\n%s", err, output)
	}

	t.Log("Starting server...")
	port, err := allocatePort()
	if err != nil {
		t.Fatalf("failed to allocate port: %v", err)
	}

	serveCmd := exec.Command(cogBinary, "serve", "-p", fmt.Sprintf("%d", port))
	serveCmd.Dir = tmpDir
	serveCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1")

	if err := serveCmd.Start(); err != nil {
		t.Fatalf("failed to start server: %v", err)
	}
	defer func() {
		serveCmd.Process.Kill()
		serveCmd.Wait()
	}()

	serverURL := fmt.Sprintf("http://127.0.0.1:%d", port)

	t.Log("Waiting for server to be ready...")
	if !waitForServerReady(serverURL, 60*time.Second) {
		t.Fatal("server did not become ready within timeout")
	}

	// Fill all 2 slots with long-running predictions (each sleeps 1s).
	const maxConcurrency = 2
	var wg sync.WaitGroup
	for i := range maxConcurrency {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			makePrediction(serverURL, idx)
		}(i)
	}

	// Poll with an overflow request until we get a 409, meaning both slots
	// are occupied. This avoids a fixed sleep that can flake on slow CI.
	// Use a client with an explicit timeout so a wedged server fails the
	// poll promptly instead of hanging the timeout-less default client.
	t.Log("Polling for 409 (all slots occupied)...")
	pollClient := &http.Client{Timeout: 5 * time.Second}
	deadline := time.Now().Add(10 * time.Second)
	var resp *http.Response
	got409 := false
	for time.Now().Before(deadline) {
		extraBody := `{"id":"extra","input":{"s":"overflow","sleep":1.0}}`
		resp, err = pollClient.Post(
			serverURL+"/predictions",
			"application/json",
			strings.NewReader(extraBody),
		)
		if err != nil {
			t.Fatalf("failed to send extra prediction: %v", err)
		}
		if resp.StatusCode == http.StatusConflict {
			got409 = true
			break
		}
		// Got a non-409 (slots weren't full yet) — close and retry.
		// Closing here (and only here) avoids a double Close via a
		// deferred Close when the deadline expires without a 409.
		resp.Body.Close()
		time.Sleep(100 * time.Millisecond)
	}
	if !got409 {
		t.Fatalf("extra prediction never returned %d (409 Conflict); slots never filled within timeout", http.StatusConflict)
	}
	defer resp.Body.Close()

	// The 409 body should carry a structured error explaining the rejection.
	var errResp struct {
		Error  string `json:"error"`
		Status string `json:"status"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if errResp.Status != "failed" {
		t.Errorf("error response status = %q, want \"failed\"", errResp.Status)
	}
	if !strings.Contains(strings.ToLower(errResp.Error), "capacity") {
		t.Errorf("error response error = %q, want string containing \"capacity\"", errResp.Error)
	}

	// Wait for the two occupying predictions to finish before teardown.
	wg.Wait()
}

// aboveLimitCogYAML is the cog config for TestConcurrentAboveLimit: it caps
// concurrency at 2 so a third in-flight prediction is rejected with 409.
// The keys under build:/concurrency: must be indented to be valid YAML —
// without indentation they would parse as unrelated top-level keys.
const aboveLimitCogYAML = `build:
  python_version: "3.11"
  predict: "predict.py:Predictor"
concurrency:
  max: 2
`

// TestSIGTERMDuringSetup tests that SIGTERM during setup() causes clean shutdown.
func TestSIGTERMDuringSetup(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping slow test in short mode")
	}

	tmpDir, err := os.MkdirTemp("", "cog-sigterm-setup-test-*")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)

	slowSetupCogYAML := `build:
  python_version: "3.12"
  predict: "predict.py:Predictor"
`
	// setup() sleeps far longer than the signal timing below, guaranteeing
	// the SIGTERM arrives while setup is still in progress.
	slowSetupPredictPy := `import time
from cog import BasePredictor

class Predictor(BasePredictor):
    def setup(self) -> None:
        time.sleep(30)

    def predict(self, s: str) -> str:
        return "hello " + s
`

	if err := os.WriteFile(filepath.Join(tmpDir, "cog.yaml"), []byte(slowSetupCogYAML), 0o644); err != nil {
		t.Fatalf("failed to write cog.yaml: %v", err)
	}
	if err := os.WriteFile(filepath.Join(tmpDir, "predict.py"), []byte(slowSetupPredictPy), 0o644); err != nil {
		t.Fatalf("failed to write predict.py: %v", err)
	}

	cogBinary, err := harness.ResolveCogBinary()
	if err != nil {
		t.Fatalf("failed to resolve cog binary: %v", err)
	}

	t.Log("Building image...")
	imageName := fmt.Sprintf("cog-sigterm-setup-test-%d", time.Now().UnixNano())
	defer func() {
		// Best-effort image cleanup; errors ignored (image may not exist).
		exec.Command("docker", "rmi", "-f", imageName).Run()
	}()

	buildCmd := exec.Command(cogBinary, "build", "-t", imageName)
	buildCmd.Dir = tmpDir
	buildCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1")
	if output, err := buildCmd.CombinedOutput(); err != nil {
		t.Fatalf("failed to build image: %v\n%s", err, output)
	}

	t.Log("Starting server...")
	port, err := allocatePort()
	if err != nil {
		t.Fatalf("failed to allocate port: %v", err)
	}

	serveCmd := exec.Command(cogBinary, "serve", "-p", fmt.Sprintf("%d", port))
	serveCmd.Dir = tmpDir
	serveCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1")

	if err := serveCmd.Start(); err != nil {
		t.Fatalf("failed to start server: %v", err)
	}
	// Guarantee the server never outlives the test, even on early t.Fatal
	// paths (e.g. a failed Signal call below), which would otherwise leak
	// the process. Disabled once the process has been reaped via Wait.
	serverReaped := false
	defer func() {
		if !serverReaped {
			serveCmd.Process.Kill()
			serveCmd.Wait()
		}
	}()

	// Poll health-check until setup has begun (status STARTING),
	// rather than a fixed sleep that can be too short on cold Docker pulls.
	t.Log("Waiting for setup to begin (STARTING status)...")
	if !waitForServerStatus(fmt.Sprintf("http://127.0.0.1:%d", port), "STARTING", 60*time.Second) {
		t.Fatal("server did not reach STARTING status within timeout")
	}

	// Send SIGTERM while setup() is still sleeping.
	t.Log("Sending SIGTERM during setup...")
	if err := serveCmd.Process.Signal(syscall.SIGTERM); err != nil {
		t.Fatalf("failed to send signal: %v", err)
	}

	// Wait for the process to exit, with a timeout. Wait runs in a
	// goroutine because it blocks until the process exits.
	done := make(chan error, 1)
	go func() {
		done <- serveCmd.Wait()
	}()

	select {
	case err := <-done:
		serverReaped = true
		t.Logf("Server exited: %v", err)
	case <-time.After(15 * time.Second):
		serveCmd.Process.Kill()
		// Reap the killed process so it does not linger as a zombie.
		<-done
		serverReaped = true
		t.Fatal("server did not exit within 15s after SIGTERM")
	}
}
33 changes: 23 additions & 10 deletions integration-tests/harness/harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ type mockUploadServer struct {

// webhookResult is the summary written to stdout by webhook-server-wait.
type webhookResult struct {
Status string `json:"status"`
OutputSize int `json:"output_size"`
HasError bool `json:"has_error"`
Metrics json.RawMessage `json:"metrics,omitempty"`
Status string `json:"status"`
OutputSize int `json:"output_size"`
HasError bool `json:"has_error"`
ErrorMessage string `json:"error_message,omitempty"`
Metrics json.RawMessage `json:"metrics,omitempty"`
}

// webhookServer accepts prediction webhook callbacks from coglet.
Expand Down Expand Up @@ -574,7 +575,8 @@ func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) {

if neg {
if !statusOK {
// Expected to fail - success!
// Expected to fail — write body to stderr so tests can assert
_, _ = ts.Stderr().Write([]byte(respBody))
return
}
} else {
Expand Down Expand Up @@ -1389,9 +1391,11 @@ func (h *Harness) cmdWebhookServerStart(ts *testscript.TestScript, neg bool, arg

// Stream-parse the JSON to extract status, measure output size, and
// capture metrics without holding the entire output string in memory.
// Output is json.RawMessage because it can be a string (single output)
// or an array (iterator/streaming output).
var payload struct {
Status string `json:"status"`
Output string `json:"output"`
Output json.RawMessage `json:"output"`
Error string `json:"error"`
Metrics json.RawMessage `json:"metrics"`
}
Expand All @@ -1416,11 +1420,20 @@ func (h *Harness) cmdWebhookServerStart(ts *testscript.TestScript, neg bool, arg
if ws.result != nil {
return
}
// Compute output size: for strings, use the unquoted length;
// for arrays or other types, use the raw JSON byte length.
outputSize := len(payload.Output)
var outputStr string
if json.Unmarshal(payload.Output, &outputStr) == nil {
outputSize = len(outputStr)
}

ws.result = &webhookResult{
Status: payload.Status,
OutputSize: len(payload.Output),
HasError: payload.Error != "",
Metrics: payload.Metrics,
Status: payload.Status,
OutputSize: outputSize,
HasError: payload.Error != "",
ErrorMessage: payload.Error,
Metrics: payload.Metrics,
}
close(ws.done)
})
Expand Down
34 changes: 34 additions & 0 deletions integration-tests/tests/async_generator_precollect.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Test that async generator output is pre-collected before response.
#
# Coglet collects all async generator yields into a list before sending
# the response. This test verifies all items arrive in the response and
# that predict_time reflects the full generation duration.

cog serve

curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'
# All 5 items should be present in the output
stdout '"output":\["chunk-0","chunk-1","chunk-2","chunk-3","chunk-4"\]'
# predict_time is expected to be at least 0.5s (5 items × 0.1s each); the regex
# below only verifies it is a decimal number with a fractional part.
stdout '"predict_time":[0-9]+\.[0-9]'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"
concurrency:
max: 1

-- predict.py --
import asyncio
from typing import AsyncIterator

from cog import BasePredictor


class Predictor(BasePredictor):
async def predict(self) -> AsyncIterator[str]:
for i in range(5):
await asyncio.sleep(0.1)
yield f"chunk-{i}"
Loading
Loading