diff --git a/.github/actions/aws-test-infra/README.md b/.github/actions/aws-test-infra/README.md new file mode 100644 index 0000000..9f9750e --- /dev/null +++ b/.github/actions/aws-test-infra/README.md @@ -0,0 +1,253 @@ +# AWS Test Infra + +Provisions and tears down AWS test infrastructure (VPC + subnet + IGW + +route table + security group + EC2 instances) for e2e workflows. Replaces +hundreds of lines of duplicated Bash + `aws-cli` with a single tested Go +binary. + +The action is shaped around the existing pattern in `loft-sh/vcluster-pro`: + +- One VPC, one subnet, one IGW, one route table, one security group per + workflow run, all tagged with a caller-supplied **consumer tag** plus a + **RunID**. +- Multiple EC2 instances launched per "role" (typically `primary`, + `worker1`, `worker2`). +- Best-effort teardown by ID, followed by a tag-based **fallback sweep** + that catches anything left behind by a run that failed before exporting + IDs. + +## Authentication + +The binary uses the default `aws-sdk-go-v2` credential chain. Calling +workflows already configured with +`aws-actions/configure-aws-credentials` (OIDC + assume-role) will pass +credentials through automatically. + +## Usage + +### Provision + +```yaml +- name: Set up Go + # Required: this action builds itself from source on every run. + uses: actions/setup-go@v5 + with: + go-version-file: go.mod # or whatever your repo uses + +- name: AWS login (OIDC) + uses: aws-actions/configure-aws-credentials@v5.1.1 + with: + role-to-assume: arn:aws:iam:::role/e2e-test-executor + aws-region: us-west-2 + +- name: Provision e2e infra + id: provision + uses: loft-sh/github-actions/.github/actions/aws-test-infra@aws-test-infra/v1 + with: + command: provision + region: us-west-2 + run-id: selinux-e2e-${{ github.run_id }}-${{ matrix.os }} + consumer-tag: SELinuxE2E=true + sg-name: selinux-e2e-${{ github.run_id }}-${{ matrix.os }} + sg-description: SELinux e2e suite + ami-owner: ${{ matrix.ami_owner }} + ami-filter: ${{ matrix.ami_filter }} + root-device: ${{ matrix.ami_root_device }} + volume-size-gb: '200' + instance-profile: e2e-test-executor + ssm-wait-timeout: 5m + ssm-wait-interval: 10s + ingress-rules: | + -1:-1:-1:10.0.0.0/16 + tcp:8443:8443:0.0.0.0/0 + tcp:30000:32767:0.0.0.0/0 + icmp:-1:-1:10.0.0.0/16 + user-data: | + #!/bin/bash + set -e + dnf install -y https://s3.us-west-2.amazonaws.com/amazon-ssm-us-west-2/latest/linux_amd64/amazon-ssm-agent.rpm + systemctl enable --now amazon-ssm-agent + +- name: Use the infra + env: + PRIMARY_PUBLIC_IP: ${{ steps.provision.outputs.primary-public-ip }} + PRIMARY_INSTANCE_ID: ${{ steps.provision.outputs.primary-instance-id }} + run: ... +``` + +The action populates `outputs.vpc-id`, `outputs.subnet-id`, etc. — see the +inputs/outputs section below for the full list. + +### Cleanup + +Cleanup must run with `if: always()` so that resources are torn down even +when the test run failed. + +```yaml +- name: Cleanup e2e infra + if: always() + uses: loft-sh/github-actions/.github/actions/aws-test-infra@aws-test-infra/v1 + with: + command: cleanup + region: us-west-2 + run-id: selinux-e2e-${{ github.run_id }}-${{ matrix.os }} + vpc-id: ${{ steps.provision.outputs.vpc-id }} + igw-id: ${{ steps.provision.outputs.igw-id }} + subnet-id: ${{ steps.provision.outputs.subnet-id }} + route-table-id: ${{ steps.provision.outputs.route-table-id }} + route-assoc-id: ${{ steps.provision.outputs.route-assoc-id }} + security-group-id: ${{ steps.provision.outputs.security-group-id }} + instance-ids: ${{ steps.provision.outputs.instance-ids }} +``` + +If the provision step failed before producing IDs, leave them blank — the +tag-based sweep will find any orphaned resources by `tag:RunID` and clean +them up. + +### Variable instance count or non-standard role names + +The defaults launch three instances tagged `primary`, `worker1`, `worker2`, +with named outputs (`primary-instance-id`, `worker1-instance-id`, +`worker2-instance-id`) for each. To launch a different count or use +arbitrary role names, override `instance-roles` and read the IDs from the +`instance-id-by-role` JSON output. + +```yaml +- name: Provision (2 instances, custom roles) + id: provision + uses: loft-sh/github-actions/.github/actions/aws-test-infra@aws-test-infra/v1 + with: + command: provision + region: us-west-2 + run-id: my-suite-${{ github.run_id }} + consumer-tag: MySuite=true + sg-name: my-suite-${{ github.run_id }} + ami-owner: '099720109477' + ami-filter: 'ubuntu/images/hvm-ssd*/ubuntu-jammy-22.04-amd64-server-*' + instance-roles: 'controller,agent' + +- name: Use instances + env: + CONTROLLER_ID: ${{ fromJSON(steps.provision.outputs.instance-id-by-role).controller }} + AGENT_ID: ${{ fromJSON(steps.provision.outputs.instance-id-by-role).agent }} + run: ... +``` + +The `instance-ids` output (CSV) and the cleanup wiring continue to work +unchanged for any role count. + +## Ingress rule format + +Each rule is `protocol:fromPort:toPort:cidr`. To pass several rules at +once, put one rule per line in the `ingress-rules` input. + +| Protocol | fromPort | toPort | CIDR | Meaning | +|---|---|---|---|---| +| `-1` | -1 | -1 | `10.0.0.0/16` | All protocols, intra-VPC | +| `tcp` | 8443 | 8443 | `0.0.0.0/0` | vCluster API, wide-open | +| `tcp` | 30000 | 32767 | `1.2.3.4/32` | Inner NodePort range, runner-only | +| `icmp` | -1 | -1 | `10.0.0.0/16` | ICMP intra-VPC | + +## Inputs + + + +| INPUT | TYPE | REQUIRED | DEFAULT | DESCRIPTION | +|--------------------------|--------|----------|-----------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| ami-architecture | string | false | `"x86_64"` | (provision) Architecture filter for AMI lookup
(e.g. x86_64, arm64). Defaults to x86_64 to match
the original Bash workflows; pass an
empty string to disable the filter. | +| ami-filter | string | false | | (provision) AMI name filter for lookup
(latest CreationDate wins) | +| ami-id | string | false | | (provision) Use this exact AMI ID
(skips lookup) | +| ami-owner | string | false | | (provision) AMI owner (account ID or alias) for lookup | +| ami-virtualization-type | string | false | `"hvm"` | (provision) Virtualization-type filter for AMI lookup
(e.g. hvm, paravirtual). Defaults to hvm to match
the original Bash workflows; pass an
empty string to disable the filter. | +| availability-zone | string | false | | (provision) AZ for the subnet (defaults to first AZ in region) | +| command | string | true | | Subcommand: provision or cleanup | +| consumer-tag | string | false | | (provision) Consumer tag in KEY=VALUE form,
e.g. SELinuxE2E=true | +| igw-id | string | false | | (cleanup) Internet gateway ID | +| ingress-rules | string | false | | (provision) Newline-separated ingress rules in protocol:fromPort:toPort:cidr
form. Example: "-1:-1:-1:10.0.0.0/16\ntcp:8443:8443:0.0.0.0/0" | +| instance-ids | string | false | | (cleanup) Comma-separated list of instance IDs | +| instance-profile | string | false | | (provision) IAM instance profile name | +| instance-roles | string | false | `"primary,worker1,worker2"` | (provision) Comma-separated role labels (one instance per role) | +| instance-running-timeout | string | false | `"30m"` | (provision) Max wait for all instances
to reach running state. Bump for
slow-boot edge cases. | +| instance-type | string | false | `"m5.xlarge"` | (provision) EC2 instance type | +| region | string | true | | AWS region | +| root-device | string | false | `"/dev/sda1"` | (provision) Root block-device name, e.g. /dev/sda1
or /dev/xvda | +| route-assoc-id | string | false | | (cleanup) Route table association ID | +| route-table-id | string | false | | (cleanup) Route table ID | +| run-id | string | true | | Unique run identifier; tagged on every
resource as RunID | +| security-group-id | string | false | | (cleanup) Security group ID | +| sg-description | string | false | | (provision) Security group description | +| sg-name | string | false | | (provision) Security group name | +| skip-direct | string | false | `"false"` | (cleanup) Skip direct cleanup; only run
the tag-based sweep | +| skip-ssm-wait | string | false | `"false"` | (provision) Skip waiting for SSM agents | +| skip-sweep | string | false | `"false"` | (cleanup) Skip the tag-based sweep; only
run direct cleanup with the supplied
IDs | +| ssm-wait-interval | string | false | `"10s"` | (provision) Polling interval for SSM agent
registration | +| ssm-wait-timeout | string | false | `"5m"` | (provision) How long to wait for
all SSM agents to register | +| strict-sweep | string | false | `"false"` | (cleanup) Fail the cleanup step on
sweep errors. Default false matches the
original Bash teardown (set +e). Set true
if you would rather see sweep
failures than silently leak resources on
AWS API hiccups. | +| subnet-cidr | string | false | `"10.0.1.0/24"` | (provision) Subnet CIDR | +| subnet-id | string | false | | (cleanup) Subnet ID | +| user-data | string | false | | (provision) Raw user-data content; written to
a temp file and base64-encoded by
the binary | +| volume-size-gb | string | false | `"100"` | (provision) Root volume size in GB | +| vpc-cidr | string | false | `"10.0.0.0/16"` | (provision) VPC CIDR | +| vpc-id | string | false | | (cleanup) VPC ID | + + + +## Outputs + + + +| OUTPUT | TYPE | DESCRIPTION | +|---------------------|--------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| ami-id | string | Resolved AMI ID | +| igw-id | string | Created internet gateway ID | +| instance-id-by-role | string | JSON map of role → instance
ID. Use for arbitrary role names
(anything other than primary/worker1/worker2). Consumer accesses with `fromJSON(steps..outputs.instance-id-by-role).`. | +| instance-ids | string | Comma-separated list of all instance IDs | +| primary-instance-id | string | Instance ID of the role labeled
"primary" (empty if not present) | +| primary-public-ip | string | Public IP of the primary instance | +| route-assoc-id | string | Created route table association ID | +| route-table-id | string | Created route table ID | +| security-group-id | string | Created security group ID | +| subnet-id | string | Created subnet ID | +| vpc-id | string | Created VPC ID | +| worker1-instance-id | string | Instance ID of the role labeled
"worker1" (empty if not present) | +| worker2-instance-id | string | Instance ID of the role labeled
"worker2" (empty if not present) | + + + +## Local development + +The Go source lives at `src/`. Build and test locally: + +```sh +cd src +go test ./... +go build -o /tmp/aws-test-infra . +/tmp/aws-test-infra provision -h +/tmp/aws-test-infra cleanup -h +``` + +## How it works (build-from-source) + +The action builds the Go binary at runtime from `src/` and runs it. There +is no separate release artifact — the consumer references a tag (e.g. +`@aws-test-infra/v1`), GitHub fetches the action source at that ref, and +the action builds + invokes it in the same job. + +This requires the **consumer's runner already has Go installed** (e.g. +via a prior `actions/setup-go` step). Both current consumers +(`vcluster-pro` selinux + prerelease workflows) do; if a future consumer +doesn't, the action emits a clear error. + +## Releasing + +Tag scheme is `aws-test-infra/v*`, e.g. `aws-test-infra/v1`. Push the tag +at the merged commit on `main`: + +```sh +git tag aws-test-infra/v1 +git push origin aws-test-infra/v1 +``` + +That's it — no release workflow, no binary upload, no SHA-256 dance. +Consumers can use `@aws-test-infra/v1` immediately. Force-pushing the +tag works the same way (next consumer run picks up the new code). diff --git a/.github/actions/aws-test-infra/action.yml b/.github/actions/aws-test-infra/action.yml new file mode 100644 index 0000000..73a4968 --- /dev/null +++ b/.github/actions/aws-test-infra/action.yml @@ -0,0 +1,317 @@ +name: 'AWS Test Infra' +description: 'Provision or tear down AWS test infrastructure (VPC + subnet + IGW + route table + security group + EC2 instances) for e2e workflows. Replaces ~150 lines of duplicated Bash + aws-cli with a tested Go binary.' + +inputs: + command: + description: 'Subcommand: provision or cleanup' + required: true + region: + description: 'AWS region' + required: true + run-id: + description: 'Unique run identifier; tagged on every resource as RunID' + required: true + + # Provision-only inputs ──────────────────────────────────────────────── + consumer-tag: + description: '(provision) Consumer tag in KEY=VALUE form, e.g. SELinuxE2E=true' + required: false + default: '' + vpc-cidr: + description: '(provision) VPC CIDR' + required: false + default: '10.0.0.0/16' + subnet-cidr: + description: '(provision) Subnet CIDR' + required: false + default: '10.0.1.0/24' + availability-zone: + description: '(provision) AZ for the subnet (defaults to first AZ in region)' + required: false + default: '' + ami-id: + description: '(provision) Use this exact AMI ID (skips lookup)' + required: false + default: '' + ami-owner: + description: '(provision) AMI owner (account ID or alias) for lookup' + required: false + default: '' + ami-filter: + description: '(provision) AMI name filter for lookup (latest CreationDate wins)' + required: false + default: '' + ami-architecture: + description: '(provision) Architecture filter for AMI lookup (e.g. x86_64, arm64). Defaults to x86_64 to match the original Bash workflows; pass an empty string to disable the filter.' + required: false + default: 'x86_64' + ami-virtualization-type: + description: '(provision) Virtualization-type filter for AMI lookup (e.g. hvm, paravirtual). Defaults to hvm to match the original Bash workflows; pass an empty string to disable the filter.' + required: false + default: 'hvm' + sg-name: + description: '(provision) Security group name' + required: false + default: '' + sg-description: + description: '(provision) Security group description' + required: false + default: '' + ingress-rules: + description: '(provision) Newline-separated ingress rules in protocol:fromPort:toPort:cidr form. Example: "-1:-1:-1:10.0.0.0/16\ntcp:8443:8443:0.0.0.0/0"' + required: false + default: '' + instance-type: + description: '(provision) EC2 instance type' + required: false + default: 'm5.xlarge' + instance-profile: + description: '(provision) IAM instance profile name' + required: false + default: '' + instance-roles: + description: '(provision) Comma-separated role labels (one instance per role)' + required: false + default: 'primary,worker1,worker2' + root-device: + description: '(provision) Root block-device name, e.g. /dev/sda1 or /dev/xvda' + required: false + default: '/dev/sda1' + volume-size-gb: + description: '(provision) Root volume size in GB' + required: false + default: '100' + user-data: + description: '(provision) Raw user-data content; written to a temp file and base64-encoded by the binary' + required: false + default: '' + ssm-wait-timeout: + description: '(provision) How long to wait for all SSM agents to register' + required: false + default: '5m' + ssm-wait-interval: + description: '(provision) Polling interval for SSM agent registration' + required: false + default: '10s' + skip-ssm-wait: + description: '(provision) Skip waiting for SSM agents' + required: false + default: 'false' + instance-running-timeout: + description: '(provision) Max wait for all instances to reach running state. Bump for slow-boot edge cases.' + required: false + default: '30m' + + # Cleanup-only inputs ───────────────────────────────────────────────── + vpc-id: + description: '(cleanup) VPC ID' + required: false + default: '' + igw-id: + description: '(cleanup) Internet gateway ID' + required: false + default: '' + subnet-id: + description: '(cleanup) Subnet ID' + required: false + default: '' + route-table-id: + description: '(cleanup) Route table ID' + required: false + default: '' + route-assoc-id: + description: '(cleanup) Route table association ID' + required: false + default: '' + security-group-id: + description: '(cleanup) Security group ID' + required: false + default: '' + instance-ids: + description: '(cleanup) Comma-separated list of instance IDs' + required: false + default: '' + skip-direct: + description: '(cleanup) Skip direct cleanup; only run the tag-based sweep' + required: false + default: 'false' + skip-sweep: + description: '(cleanup) Skip the tag-based sweep; only run direct cleanup with the supplied IDs' + required: false + default: 'false' + strict-sweep: + description: '(cleanup) Fail the cleanup step on sweep errors. Default false matches the original Bash teardown (set +e). Set true if you would rather see sweep failures than silently leak resources on AWS API hiccups.' + required: false + default: 'false' + +outputs: + vpc-id: + description: 'Created VPC ID' + value: ${{ steps.run.outputs.vpc_id }} + igw-id: + description: 'Created internet gateway ID' + value: ${{ steps.run.outputs.igw_id }} + subnet-id: + description: 'Created subnet ID' + value: ${{ steps.run.outputs.subnet_id }} + route-table-id: + description: 'Created route table ID' + value: ${{ steps.run.outputs.route_table_id }} + route-assoc-id: + description: 'Created route table association ID' + value: ${{ steps.run.outputs.route_assoc_id }} + security-group-id: + description: 'Created security group ID' + value: ${{ steps.run.outputs.security_group_id }} + ami-id: + description: 'Resolved AMI ID' + value: ${{ steps.run.outputs.ami_id }} + primary-public-ip: + description: 'Public IP of the primary instance' + value: ${{ steps.run.outputs.primary_public_ip }} + instance-ids: + description: 'Comma-separated list of all instance IDs' + value: ${{ steps.run.outputs.instance_ids }} + primary-instance-id: + description: 'Instance ID of the role labeled "primary" (empty if not present)' + value: ${{ steps.run.outputs.instance_id_primary }} + worker1-instance-id: + description: 'Instance ID of the role labeled "worker1" (empty if not present)' + value: ${{ steps.run.outputs.instance_id_worker1 }} + worker2-instance-id: + description: 'Instance ID of the role labeled "worker2" (empty if not present)' + value: ${{ steps.run.outputs.instance_id_worker2 }} + instance-id-by-role: + description: 'JSON map of role → instance ID. Use for arbitrary role names (anything other than primary/worker1/worker2). Consumer accesses with `fromJSON(steps..outputs.instance-id-by-role).`.' + value: ${{ steps.run.outputs.instance_id_by_role }} + +runs: + using: 'composite' + steps: + - name: Build aws-test-infra + id: build + shell: bash + working-directory: ${{ github.action_path }}/src + run: | + # The action is build-from-source on every run; we mirror run-ginkgo's + # pattern of assuming Go is already installed in the consumer + # workflow. Both vcluster-pro consumers (selinux + prerel) call + # actions/setup-go before this action, so the toolchain is in place. + # If a future consumer doesn't already have Go, they need to add + # setup-go before referencing this action. + if ! command -v go >/dev/null 2>&1; then + echo "::error::Go is not installed in the runner. Add 'uses: actions/setup-go@v5' before this action." + exit 1 + fi + BINARY="$(mktemp -d)/aws-test-infra" + CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o "$BINARY" . + echo "binary=$BINARY" >> "$GITHUB_OUTPUT" + + - name: Stage user-data file + id: userdata + if: inputs.command == 'provision' && inputs.user-data != '' + shell: bash + env: + USER_DATA: ${{ inputs.user-data }} + run: | + UD_PATH="$(mktemp)" + printf '%s' "$USER_DATA" > "$UD_PATH" + echo "path=$UD_PATH" >> "$GITHUB_OUTPUT" + + - name: Run aws-test-infra + id: run + shell: bash + env: + BINARY: ${{ steps.build.outputs.binary }} + INPUT_CMD: ${{ inputs.command }} + INPUT_REGION: ${{ inputs.region }} + INPUT_RUN_ID: ${{ inputs.run-id }} + # Provision + INPUT_CONSUMER_TAG: ${{ inputs.consumer-tag }} + INPUT_VPC_CIDR: ${{ inputs.vpc-cidr }} + INPUT_SUBNET_CIDR: ${{ inputs.subnet-cidr }} + INPUT_AVAILABILITY_ZONE: ${{ inputs.availability-zone }} + INPUT_AMI_ID: ${{ inputs.ami-id }} + INPUT_AMI_OWNER: ${{ inputs.ami-owner }} + INPUT_AMI_FILTER: ${{ inputs.ami-filter }} + INPUT_AMI_ARCHITECTURE: ${{ inputs.ami-architecture }} + INPUT_AMI_VIRTUALIZATION_TYPE: ${{ inputs.ami-virtualization-type }} + INPUT_SG_NAME: ${{ inputs.sg-name }} + INPUT_SG_DESCRIPTION: ${{ inputs.sg-description }} + INPUT_INGRESS_RULES: ${{ inputs.ingress-rules }} + INPUT_INSTANCE_TYPE: ${{ inputs.instance-type }} + INPUT_INSTANCE_PROFILE: ${{ inputs.instance-profile }} + INPUT_INSTANCE_ROLES: ${{ inputs.instance-roles }} + INPUT_ROOT_DEVICE: ${{ inputs.root-device }} + INPUT_VOLUME_SIZE_GB: ${{ inputs.volume-size-gb }} + INPUT_USER_DATA_FILE: ${{ steps.userdata.outputs.path }} + INPUT_SSM_WAIT_TIMEOUT: ${{ inputs.ssm-wait-timeout }} + INPUT_SSM_WAIT_INTERVAL: ${{ inputs.ssm-wait-interval }} + INPUT_SKIP_SSM_WAIT: ${{ inputs.skip-ssm-wait }} + INPUT_INSTANCE_RUNNING_TIMEOUT: ${{ inputs.instance-running-timeout }} + # Cleanup + INPUT_VPC_ID: ${{ inputs.vpc-id }} + INPUT_IGW_ID: ${{ inputs.igw-id }} + INPUT_SUBNET_ID: ${{ inputs.subnet-id }} + INPUT_ROUTE_TABLE_ID: ${{ inputs.route-table-id }} + INPUT_ROUTE_ASSOC_ID: ${{ inputs.route-assoc-id }} + INPUT_SECURITY_GROUP_ID: ${{ inputs.security-group-id }} + INPUT_INSTANCE_IDS: ${{ inputs.instance-ids }} + INPUT_SKIP_DIRECT: ${{ inputs.skip-direct }} + INPUT_SKIP_SWEEP: ${{ inputs.skip-sweep }} + INPUT_STRICT_SWEEP: ${{ inputs.strict-sweep }} + run: | + set -euo pipefail + ARGS=(-region="$INPUT_REGION" -run-id="$INPUT_RUN_ID") + + case "$INPUT_CMD" in + provision) + [ -n "$INPUT_CONSUMER_TAG" ] && ARGS+=(-consumer-tag="$INPUT_CONSUMER_TAG") + ARGS+=(-vpc-cidr="$INPUT_VPC_CIDR" -subnet-cidr="$INPUT_SUBNET_CIDR") + [ -n "$INPUT_AVAILABILITY_ZONE" ] && ARGS+=(-availability-zone="$INPUT_AVAILABILITY_ZONE") + [ -n "$INPUT_AMI_ID" ] && ARGS+=(-ami-id="$INPUT_AMI_ID") + [ -n "$INPUT_AMI_OWNER" ] && ARGS+=(-ami-owner="$INPUT_AMI_OWNER") + [ -n "$INPUT_AMI_FILTER" ] && ARGS+=(-ami-filter="$INPUT_AMI_FILTER") + [ -n "$INPUT_AMI_ARCHITECTURE" ] && ARGS+=(-ami-architecture="$INPUT_AMI_ARCHITECTURE") + [ -n "$INPUT_AMI_VIRTUALIZATION_TYPE" ] && ARGS+=(-ami-virtualization-type="$INPUT_AMI_VIRTUALIZATION_TYPE") + [ -n "$INPUT_SG_NAME" ] && ARGS+=(-sg-name="$INPUT_SG_NAME") + [ -n "$INPUT_SG_DESCRIPTION" ] && ARGS+=(-sg-description="$INPUT_SG_DESCRIPTION") + ARGS+=(-instance-type="$INPUT_INSTANCE_TYPE" -instance-roles="$INPUT_INSTANCE_ROLES" -root-device="$INPUT_ROOT_DEVICE" -volume-size-gb="$INPUT_VOLUME_SIZE_GB") + [ -n "$INPUT_INSTANCE_PROFILE" ] && ARGS+=(-instance-profile="$INPUT_INSTANCE_PROFILE") + [ -n "$INPUT_USER_DATA_FILE" ] && ARGS+=(-user-data-file="$INPUT_USER_DATA_FILE") + ARGS+=(-ssm-wait-timeout="$INPUT_SSM_WAIT_TIMEOUT" -ssm-wait-interval="$INPUT_SSM_WAIT_INTERVAL") + [ "$INPUT_SKIP_SSM_WAIT" = "true" ] && ARGS+=(-skip-ssm-wait) + [ -n "$INPUT_INSTANCE_RUNNING_TIMEOUT" ] && ARGS+=(-instance-running-timeout="$INPUT_INSTANCE_RUNNING_TIMEOUT") + # Each line in INPUT_INGRESS_RULES becomes a -ingress flag. + if [ -n "$INPUT_INGRESS_RULES" ]; then + while IFS= read -r line; do + line="$(echo "$line" | xargs)" # trim + [ -z "$line" ] && continue + ARGS+=(-ingress="$line") + done <<< "$INPUT_INGRESS_RULES" + fi + ARGS+=(-output="$GITHUB_OUTPUT" -output-format=github-output) + "$BINARY" provision "${ARGS[@]}" + ;; + cleanup) + [ -n "$INPUT_VPC_ID" ] && ARGS+=(-vpc-id="$INPUT_VPC_ID") + [ -n "$INPUT_IGW_ID" ] && ARGS+=(-igw-id="$INPUT_IGW_ID") + [ -n "$INPUT_SUBNET_ID" ] && ARGS+=(-subnet-id="$INPUT_SUBNET_ID") + [ -n "$INPUT_ROUTE_TABLE_ID" ] && ARGS+=(-route-table-id="$INPUT_ROUTE_TABLE_ID") + [ -n "$INPUT_ROUTE_ASSOC_ID" ] && ARGS+=(-route-assoc-id="$INPUT_ROUTE_ASSOC_ID") + [ -n "$INPUT_SECURITY_GROUP_ID" ] && ARGS+=(-security-group-id="$INPUT_SECURITY_GROUP_ID") + [ -n "$INPUT_INSTANCE_IDS" ] && ARGS+=(-instance-ids="$INPUT_INSTANCE_IDS") + [ "$INPUT_SKIP_DIRECT" = "true" ] && ARGS+=(-skip-direct) + [ "$INPUT_SKIP_SWEEP" = "true" ] && ARGS+=(-skip-sweep) + [ "$INPUT_STRICT_SWEEP" = "true" ] && ARGS+=(-strict-sweep) + "$BINARY" cleanup "${ARGS[@]}" + ;; + *) + echo "::error::Unknown command: $INPUT_CMD (must be 'provision' or 'cleanup')" + exit 1 + ;; + esac + +branding: + icon: 'cloud' + color: 'orange' diff --git a/.github/actions/aws-test-infra/src/.gitignore b/.github/actions/aws-test-infra/src/.gitignore new file mode 100644 index 0000000..b99712f --- /dev/null +++ b/.github/actions/aws-test-infra/src/.gitignore @@ -0,0 +1,3 @@ +# Local build artifact from `go build .` — the action builds at runtime, +# so this binary should never be committed. +aws-test-infra diff --git a/.github/actions/aws-test-infra/src/awsclient.go b/.github/actions/aws-test-infra/src/awsclient.go new file mode 100644 index 0000000..575df32 --- /dev/null +++ b/.github/actions/aws-test-infra/src/awsclient.go @@ -0,0 +1,104 @@ +package main + +import ( + "context" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ssm" +) + +// EC2API is the subset of the EC2 client we use. Defined as an interface so +// tests can supply a fake without spinning up LocalStack or running real AWS. +type EC2API interface { + CreateVpc(ctx context.Context, params *ec2.CreateVpcInput, optFns ...func(*ec2.Options)) (*ec2.CreateVpcOutput, error) + ModifyVpcAttribute(ctx context.Context, params *ec2.ModifyVpcAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyVpcAttributeOutput, error) + DeleteVpc(ctx context.Context, params *ec2.DeleteVpcInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVpcOutput, error) + DescribeVpcs(ctx context.Context, params *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) + + CreateInternetGateway(ctx context.Context, params *ec2.CreateInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.CreateInternetGatewayOutput, error) + AttachInternetGateway(ctx context.Context, params *ec2.AttachInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.AttachInternetGatewayOutput, error) + DetachInternetGateway(ctx context.Context, params *ec2.DetachInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DetachInternetGatewayOutput, error) + DeleteInternetGateway(ctx context.Context, params *ec2.DeleteInternetGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DeleteInternetGatewayOutput, error) + DescribeInternetGateways(ctx context.Context, params *ec2.DescribeInternetGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) + + DescribeAvailabilityZones(ctx context.Context, params *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) + CreateSubnet(ctx context.Context, params *ec2.CreateSubnetInput, optFns ...func(*ec2.Options)) (*ec2.CreateSubnetOutput, error) + ModifySubnetAttribute(ctx context.Context, params *ec2.ModifySubnetAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifySubnetAttributeOutput, error) + DeleteSubnet(ctx context.Context, params *ec2.DeleteSubnetInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSubnetOutput, error) + DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) + + CreateRouteTable(ctx context.Context, params *ec2.CreateRouteTableInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteTableOutput, error) + CreateRoute(ctx context.Context, params *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) + AssociateRouteTable(ctx context.Context, params *ec2.AssociateRouteTableInput, optFns ...func(*ec2.Options)) (*ec2.AssociateRouteTableOutput, error) + DisassociateRouteTable(ctx context.Context, params *ec2.DisassociateRouteTableInput, optFns ...func(*ec2.Options)) (*ec2.DisassociateRouteTableOutput, error) + DeleteRouteTable(ctx context.Context, params *ec2.DeleteRouteTableInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteTableOutput, error) + DescribeRouteTables(ctx context.Context, params *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) + + CreateSecurityGroup(ctx context.Context, params *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) + AuthorizeSecurityGroupIngress(ctx context.Context, params *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) + DeleteSecurityGroup(ctx context.Context, params *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) + DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) + + DescribeImages(ctx context.Context, params *ec2.DescribeImagesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error) + RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) + TerminateInstances(ctx context.Context, params *ec2.TerminateInstancesInput, optFns ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) + DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) +} + +// SSMAPI is the subset of the SSM client we use. +type SSMAPI interface { + DescribeInstanceInformation(ctx context.Context, params *ssm.DescribeInstanceInformationInput, optFns ...func(*ssm.Options)) (*ssm.DescribeInstanceInformationOutput, error) +} + +// EC2Waiter is the subset of EC2 waiters we use, abstracted so tests can +// short-circuit them. +type EC2Waiter interface { + WaitInstanceRunning(ctx context.Context, ids []string) error + WaitInstanceTerminated(ctx context.Context, ids []string) error +} + +// ec2WaiterAdapter adapts the SDK's typed waiters to our interface. +// +// The instance-running max-wait is configurable so workflows that hit +// slow-boot edge cases (rare AWS capacity issues, large image pulls in +// user-data) can extend it. The SDK's default for instance-running is 10 +// minutes (40 attempts × 15s); we cap higher to give a meaningful safety +// margin while still failing rather than hanging indefinitely. +// +// Termination is unconditionally fast in practice, so terminated wait +// stays at the package default. +type ec2WaiterAdapter struct { + client *ec2.Client + instanceRunningTimeout time.Duration +} + +// effectiveInstanceRunningTimeout resolves the configured timeout to a +// usable value, falling back to the package default when the field is +// unset or non-positive. Extracted so the wiring is unit-testable +// independently of the SDK waiter (which can't be cheaply mocked). +func (a *ec2WaiterAdapter) effectiveInstanceRunningTimeout() time.Duration { + if a.instanceRunningTimeout <= 0 { + return defaultWaiterMaxWait + } + return a.instanceRunningTimeout +} + +func (a *ec2WaiterAdapter) WaitInstanceRunning(ctx context.Context, ids []string) error { + w := ec2.NewInstanceRunningWaiter(a.client) + return w.Wait(ctx, &ec2.DescribeInstancesInput{InstanceIds: ids}, a.effectiveInstanceRunningTimeout()) +} + +func (a *ec2WaiterAdapter) WaitInstanceTerminated(ctx context.Context, ids []string) error { + w := ec2.NewInstanceTerminatedWaiter(a.client) + return w.Wait(ctx, &ec2.DescribeInstancesInput{InstanceIds: ids}, defaultWaiterMaxWait) +} + +// loadAWSConfig builds a default-chain aws.Config bound to the given region. +// Workflows that already used aws-actions/configure-aws-credentials will have +// AWS_* env vars set; this picks them up automatically. +func loadAWSConfig(ctx context.Context, region string) (aws.Config, error) { + return config.LoadDefaultConfig(ctx, config.WithRegion(region)) +} diff --git a/.github/actions/aws-test-infra/src/awsclient_test.go b/.github/actions/aws-test-infra/src/awsclient_test.go new file mode 100644 index 0000000..540e2bf --- /dev/null +++ b/.github/actions/aws-test-infra/src/awsclient_test.go @@ -0,0 +1,28 @@ +package main + +import ( + "testing" + "time" +) + +func TestEC2WaiterAdapter_EffectiveInstanceRunningTimeout(t *testing.T) { + tests := []struct { + name string + configured time.Duration + want time.Duration + }{ + {"zero falls back to default", 0, defaultWaiterMaxWait}, + {"negative falls back to default", -5 * time.Minute, defaultWaiterMaxWait}, + {"positive value used as-is", 45 * time.Minute, 45 * time.Minute}, + {"default constant is 30 minutes", -1, 30 * time.Minute}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + a := &ec2WaiterAdapter{instanceRunningTimeout: tt.configured} + if got := a.effectiveInstanceRunningTimeout(); got != tt.want { + t.Errorf("effectiveInstanceRunningTimeout() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/.github/actions/aws-test-infra/src/cleanup.go b/.github/actions/aws-test-infra/src/cleanup.go new file mode 100644 index 0000000..ae4a954 --- /dev/null +++ b/.github/actions/aws-test-infra/src/cleanup.go @@ -0,0 +1,327 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log/slog" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +// CleanupConfig is the parsed flag set for `cleanup`. +type CleanupConfig struct { + Region string + RunID string + + // Resource IDs from a successful provision. All optional — anything + // missing falls through to the tag-based sweep. + VPCID string + IGWID string + SubnetID string + RouteTableID string + RouteAssocID string + SecurityGroupID string + InstanceIDs []string + + SkipDirect bool + SkipSweep bool + StrictSweep bool +} + +func runCleanup(ctx context.Context, logger *slog.Logger, name string, args []string) error { + fs := flag.NewFlagSet(name, flag.ContinueOnError) + cfg := CleanupConfig{} + var instanceIDsCSV string + fs.StringVar(&cfg.Region, "region", "", "AWS region (required)") + fs.StringVar(&cfg.RunID, "run-id", "", "RunID tag value to use for the fallback sweep (required)") + fs.StringVar(&cfg.VPCID, "vpc-id", "", "VPC ID") + fs.StringVar(&cfg.IGWID, "igw-id", "", "Internet gateway ID") + fs.StringVar(&cfg.SubnetID, "subnet-id", "", "Subnet ID") + fs.StringVar(&cfg.RouteTableID, "route-table-id", "", "Route table ID") + fs.StringVar(&cfg.RouteAssocID, "route-assoc-id", "", "Route table association ID") + fs.StringVar(&cfg.SecurityGroupID, "security-group-id", "", "Security group ID") + fs.StringVar(&instanceIDsCSV, "instance-ids", "", "Comma-separated list of instance IDs") + fs.BoolVar(&cfg.SkipDirect, "skip-direct", false, "Skip direct cleanup (only run the tag-based sweep)") + fs.BoolVar(&cfg.SkipSweep, "skip-sweep", false, "Skip the tag-based sweep (only run direct cleanup with the supplied IDs)") + fs.BoolVar(&cfg.StrictSweep, "strict-sweep", false, "Fail the cleanup step on sweep errors. Default: log and continue (matches the original Bash teardown's set +e behavior).") + if err := fs.Parse(args); err != nil { + return fmt.Errorf("parse cleanup flags: %w", err) + } + if err := finalizeCleanupConfig(&cfg, instanceIDsCSV); err != nil { + return err + } + + awsCfg, err := loadAWSConfig(ctx, cfg.Region) + if err != nil { + return fmt.Errorf("load aws config: %w", err) + } + c := ec2.NewFromConfig(awsCfg) + waiter := &ec2WaiterAdapter{client: c} + + return Cleanup(ctx, logger, c, waiter, cfg) +} + +// finalizeCleanupConfig validates required fields and parses raw form +// values into cfg's derived fields. Pure function — no AWS, no I/O. +func finalizeCleanupConfig(cfg *CleanupConfig, instanceIDsCSV string) error { + if cfg.Region == "" { + return errors.New("-region is required") + } + if cfg.RunID == "" && !cfg.SkipSweep { + return errors.New("-run-id is required (or pass -skip-sweep to disable the sweep)") + } + cfg.InstanceIDs = splitCSV(instanceIDsCSV) + return nil +} + +// Cleanup is the testable core of the cleanup command. It mirrors the +// existing teardown Bash exactly: best-effort direct deletion in dependency +// order, then a tag-based sweep that catches anything the direct path +// missed (typically resources from a run that failed before exporting IDs). +// +// All errors are logged but never abort the cleanup — the goal is "leave +// nothing behind on a torn-down run". +func Cleanup( + ctx context.Context, + logger *slog.Logger, + c EC2API, + waiter EC2Waiter, + cfg CleanupConfig, +) error { + if !cfg.SkipDirect { + directCleanup(ctx, logger, c, waiter, cfg) + } + if !cfg.SkipSweep { + if err := sweepByTag(ctx, logger, c, waiter, cfg.RunID); err != nil { + // The original Bash teardown ran under `set +e`, so every + // failure was silently absorbed and the step exited zero. + // We mirror that by default — workflows use `if: always()` + // for cleanup precisely so cleanup never fails the run, and + // returning sweep errors here would break that contract. + // Set `-strict-sweep` to opt back into hard failure. + if cfg.StrictSweep { + return err + } + logger.Error("sweep encountered errors; continuing because -strict-sweep is off", "err", err) + } + } + return nil +} + +func directCleanup( + ctx context.Context, + logger *slog.Logger, + c EC2API, + waiter EC2Waiter, + cfg CleanupConfig, +) { + if len(cfg.InstanceIDs) > 0 { + logger.Info("terminating instances", "ids", cfg.InstanceIDs) + if _, err := c.TerminateInstances(ctx, &ec2.TerminateInstancesInput{ + InstanceIds: cfg.InstanceIDs, + }); err != nil { + logger.Warn("terminate-instances failed", "err", err) + } else if err := waiter.WaitInstanceTerminated(ctx, cfg.InstanceIDs); err != nil { + logger.Warn("wait instance-terminated failed", "err", err) + } + } + if cfg.SecurityGroupID != "" { + logger.Info("deleting security group", "id", cfg.SecurityGroupID) + if _, err := c.DeleteSecurityGroup(ctx, &ec2.DeleteSecurityGroupInput{ + GroupId: aws.String(cfg.SecurityGroupID), + }); err != nil { + logger.Warn("delete-security-group failed", "err", err) + } + } + if cfg.RouteAssocID != "" { + if _, err := c.DisassociateRouteTable(ctx, &ec2.DisassociateRouteTableInput{ + AssociationId: aws.String(cfg.RouteAssocID), + }); err != nil { + logger.Warn("disassociate-route-table failed", "err", err) + } + } + if cfg.RouteTableID != "" { + if _, err := c.DeleteRouteTable(ctx, &ec2.DeleteRouteTableInput{ + RouteTableId: aws.String(cfg.RouteTableID), + }); err != nil { + logger.Warn("delete-route-table failed", "err", err) + } + } + if cfg.SubnetID != "" { + if _, err := c.DeleteSubnet(ctx, &ec2.DeleteSubnetInput{ + SubnetId: aws.String(cfg.SubnetID), + }); err != nil { + logger.Warn("delete-subnet failed", "err", err) + } + } + if cfg.IGWID != "" { + if cfg.VPCID != "" { + if _, err := c.DetachInternetGateway(ctx, &ec2.DetachInternetGatewayInput{ + InternetGatewayId: aws.String(cfg.IGWID), + VpcId: aws.String(cfg.VPCID), + }); err != nil { + logger.Warn("detach-internet-gateway failed", "err", err) + } + } + if _, err := c.DeleteInternetGateway(ctx, &ec2.DeleteInternetGatewayInput{ + InternetGatewayId: aws.String(cfg.IGWID), + }); err != nil { + logger.Warn("delete-internet-gateway failed", "err", err) + } + } + if cfg.VPCID != "" { + if _, err := c.DeleteVpc(ctx, &ec2.DeleteVpcInput{ + VpcId: aws.String(cfg.VPCID), + }); err != nil { + logger.Warn("delete-vpc failed", "err", err) + } + } +} + +// sweepByTag finds and deletes every resource that matches Name=tag:RunID +// Values=. The order — instances → SGs → route tables → subnets → +// IGWs → VPCs — matches the dependency chain so deletes don't fail because +// of in-use checks. +func sweepByTag(ctx context.Context, logger *slog.Logger, c EC2API, waiter EC2Waiter, runID string) error { + logger.Info("running tag-based sweep", "run_id", runID) + tagFilter := []types.Filter{{Name: aws.String("tag:RunID"), Values: []string{runID}}} + var errs []error + + // Instances + instOut, err := c.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + Filters: append(append([]types.Filter{}, tagFilter...), + types.Filter{Name: aws.String("instance-state-name"), + Values: []string{"pending", "running", "stopping", "stopped"}}), + }) + if err != nil { + errs = append(errs, fmt.Errorf("describe-instances (sweep): %w", err)) + } else { + var sweepInstances []string + for _, r := range instOut.Reservations { + for _, i := range r.Instances { + if id := aws.ToString(i.InstanceId); id != "" { + sweepInstances = append(sweepInstances, id) + } + } + } + if len(sweepInstances) > 0 { + logger.Info("sweep: terminating instances", "count", len(sweepInstances), "ids", strings.Join(sweepInstances, ",")) + if _, err := c.TerminateInstances(ctx, &ec2.TerminateInstancesInput{InstanceIds: sweepInstances}); err != nil { + logger.Warn("sweep: terminate-instances failed", "err", err) + } + if err := waiter.WaitInstanceTerminated(ctx, sweepInstances); err != nil { + logger.Warn("sweep: wait-instance-terminated failed", "err", err) + } + } + } + + // Security groups + sgOut, err := c.DescribeSecurityGroups(ctx, &ec2.DescribeSecurityGroupsInput{Filters: tagFilter}) + if err != nil { + errs = append(errs, fmt.Errorf("describe-security-groups (sweep): %w", err)) + } else { + for _, sg := range sgOut.SecurityGroups { + id := aws.ToString(sg.GroupId) + if id == "" { + continue + } + if _, err := c.DeleteSecurityGroup(ctx, &ec2.DeleteSecurityGroupInput{GroupId: aws.String(id)}); err != nil { + logger.Warn("sweep: delete-security-group failed", "id", id, "err", err) + } + } + } + + // Route tables — disassociate every association first, then delete. + rtOut, err := c.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{Filters: tagFilter}) + if err != nil { + errs = append(errs, fmt.Errorf("describe-route-tables (sweep): %w", err)) + } else { + for _, rt := range rtOut.RouteTables { + id := aws.ToString(rt.RouteTableId) + if id == "" { + continue + } + for _, a := range rt.Associations { + aid := aws.ToString(a.RouteTableAssociationId) + if aid == "" { + continue + } + if _, err := c.DisassociateRouteTable(ctx, &ec2.DisassociateRouteTableInput{ + AssociationId: aws.String(aid), + }); err != nil { + logger.Warn("sweep: disassociate-route-table failed", "id", aid, "err", err) + } + } + if _, err := c.DeleteRouteTable(ctx, &ec2.DeleteRouteTableInput{RouteTableId: aws.String(id)}); err != nil { + logger.Warn("sweep: delete-route-table failed", "id", id, "err", err) + } + } + } + + // Subnets + subOut, err := c.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{Filters: tagFilter}) + if err != nil { + errs = append(errs, fmt.Errorf("describe-subnets (sweep): %w", err)) + } else { + for _, sn := range subOut.Subnets { + id := aws.ToString(sn.SubnetId) + if id == "" { + continue + } + if _, err := c.DeleteSubnet(ctx, &ec2.DeleteSubnetInput{SubnetId: aws.String(id)}); err != nil { + logger.Warn("sweep: delete-subnet failed", "id", id, "err", err) + } + } + } + + // Internet gateways — detach from every attached VPC, then delete. + igwOut, err := c.DescribeInternetGateways(ctx, &ec2.DescribeInternetGatewaysInput{Filters: tagFilter}) + if err != nil { + errs = append(errs, fmt.Errorf("describe-internet-gateways (sweep): %w", err)) + } else { + for _, igw := range igwOut.InternetGateways { + id := aws.ToString(igw.InternetGatewayId) + if id == "" { + continue + } + for _, att := range igw.Attachments { + vid := aws.ToString(att.VpcId) + if vid == "" { + continue + } + if _, err := c.DetachInternetGateway(ctx, &ec2.DetachInternetGatewayInput{ + InternetGatewayId: aws.String(id), + VpcId: aws.String(vid), + }); err != nil { + logger.Warn("sweep: detach-internet-gateway failed", "igw", id, "vpc", vid, "err", err) + } + } + if _, err := c.DeleteInternetGateway(ctx, &ec2.DeleteInternetGatewayInput{InternetGatewayId: aws.String(id)}); err != nil { + logger.Warn("sweep: delete-internet-gateway failed", "id", id, "err", err) + } + } + } + + // VPCs + vpcOut, err := c.DescribeVpcs(ctx, &ec2.DescribeVpcsInput{Filters: tagFilter}) + if err != nil { + errs = append(errs, fmt.Errorf("describe-vpcs (sweep): %w", err)) + } else { + for _, v := range vpcOut.Vpcs { + id := aws.ToString(v.VpcId) + if id == "" { + continue + } + if _, err := c.DeleteVpc(ctx, &ec2.DeleteVpcInput{VpcId: aws.String(id)}); err != nil { + logger.Warn("sweep: delete-vpc failed", "id", id, "err", err) + } + } + } + + return errors.Join(errs...) +} diff --git a/.github/actions/aws-test-infra/src/cleanup_test.go b/.github/actions/aws-test-infra/src/cleanup_test.go new file mode 100644 index 0000000..071007c --- /dev/null +++ b/.github/actions/aws-test-infra/src/cleanup_test.go @@ -0,0 +1,356 @@ +package main + +import ( + "context" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +func TestCleanup_DirectOrdering(t *testing.T) { + // Direct cleanup must mirror the existing Bash teardown order. Subnet + // can't be deleted while instances are still attached; SG can't be + // deleted while ENIs reference it; VPC can't be deleted while it owns + // any of the above. + c := &fakeEC2{} + w := &fakeWaiter{} + cfg := CleanupConfig{ + Region: "us-west-2", + RunID: "run-42", + VPCID: "vpc-1", + IGWID: "igw-1", + SubnetID: "subnet-1", + RouteTableID: "rtb-1", + RouteAssocID: "rtbassoc-1", + SecurityGroupID: "sg-1", + InstanceIDs: []string{"i-1", "i-2", "i-3"}, + SkipSweep: true, + } + + if err := Cleanup(context.Background(), newTestLogger(), c, w, cfg); err != nil { + t.Fatalf("Cleanup: %v", err) + } + + want := []string{ + "TerminateInstances", + "DeleteSecurityGroup", + "DisassociateRouteTable", + "DeleteRouteTable", + "DeleteSubnet", + "DetachInternetGateway", + "DeleteInternetGateway", + "DeleteVpc", + } + seq := methodSequence(c.calls) + if err := requireOrdering(seq, want); err != nil { + t.Fatal(err) + } + + // Dependency-critical strict checks. AWS rejects the dependent + // delete with InUse errors if the parent isn't torn down first; + // these catch bugs that requireOrdering's subsequence match would + // silently accept. + // + // Disassociate must IMMEDIATELY precede DeleteRouteTable: any other + // call between them would mean we tried to delete the route table + // while still associated, which AWS rejects. + if err := requireImmediatelyAfter(seq, "DisassociateRouteTable", "DeleteRouteTable"); err != nil { + t.Errorf("disassociate→delete-RT not immediate: %v", err) + } + // Same for IGW: detach must immediately precede delete. + if err := requireImmediatelyAfter(seq, "DetachInternetGateway", "DeleteInternetGateway"); err != nil { + t.Errorf("detach→delete-IGW not immediate: %v", err) + } + // Subnet/SG/VPC deletes can happen only after instances are + // terminated (else "DependencyViolation" / "InvalidGroup.InUse"). + for _, after := range []string{"DeleteSecurityGroup", "DeleteSubnet", "DeleteVpc"} { + if err := requireBefore(seq, "TerminateInstances", after); err != nil { + t.Errorf("TerminateInstances must precede %s: %v", after, err) + } + } + + if len(w.termCalls) != 1 { + t.Errorf("expected 1 wait-instance-terminated call, got %d", len(w.termCalls)) + } +} + +func TestCleanup_SweepOnlyHandlesOrphans(t *testing.T) { + // A run that fails before exporting any IDs leaves no direct cleanup + // targets; the sweep must still find the orphaned resources and + // delete them in dependency order. + c := &fakeEC2{ + sweepResources: sweepFixture{ + Instances: []string{"i-orphan-1", "i-orphan-2"}, + SGs: []string{"sg-orphan"}, + RouteTables: []routeTableFixture{ + {ID: "rtb-orphan", AssociationIDs: []string{"rtbassoc-orphan"}}, + }, + Subnets: []string{"subnet-orphan"}, + IGWs: []igwFixture{ + {ID: "igw-orphan", VPCs: []string{"vpc-orphan"}}, + }, + VPCs: []string{"vpc-orphan"}, + }, + } + w := &fakeWaiter{} + cfg := CleanupConfig{Region: "us-west-2", RunID: "run-42", SkipDirect: true} + + if err := Cleanup(context.Background(), newTestLogger(), c, w, cfg); err != nil { + t.Fatalf("Cleanup: %v", err) + } + + // Ordering: instances → SG → RT (disassoc + delete) → subnet → IGW + // (detach + delete) → VPC. + want := []string{ + "DescribeInstances", + "TerminateInstances", + "DescribeSecurityGroups", + "DeleteSecurityGroup", + "DescribeRouteTables", + "DisassociateRouteTable", + "DeleteRouteTable", + "DescribeSubnets", + "DeleteSubnet", + "DescribeInternetGateways", + "DetachInternetGateway", + "DeleteInternetGateway", + "DescribeVpcs", + "DeleteVpc", + } + if err := requireOrdering(methodSequence(c.calls), want); err != nil { + t.Fatalf("sweep ordering wrong: %v", err) + } +} + +func TestCleanup_SweepFiltersByRunIDTag(t *testing.T) { + // The sweep must filter by tag:RunID = the supplied run-id. A bug + // here could cause cleanup to delete unrelated resources in the + // account. + c := &fakeEC2{} + w := &fakeWaiter{} + cfg := CleanupConfig{Region: "us-west-2", RunID: "run-42", SkipDirect: true} + + if err := Cleanup(context.Background(), newTestLogger(), c, w, cfg); err != nil { + t.Fatalf("Cleanup: %v", err) + } + + for _, call := range c.calls { + switch in := call.Input.(type) { + case *ec2.DescribeVpcsInput: + assertHasTagFilter(t, "DescribeVpcs", in.Filters, "run-42") + case *ec2.DescribeSubnetsInput: + assertHasTagFilter(t, "DescribeSubnets", in.Filters, "run-42") + case *ec2.DescribeRouteTablesInput: + assertHasTagFilter(t, "DescribeRouteTables", in.Filters, "run-42") + case *ec2.DescribeSecurityGroupsInput: + assertHasTagFilter(t, "DescribeSecurityGroups", in.Filters, "run-42") + case *ec2.DescribeInternetGatewaysInput: + assertHasTagFilter(t, "DescribeInternetGateways", in.Filters, "run-42") + case *ec2.DescribeInstancesInput: + assertHasTagFilter(t, "DescribeInstances", in.Filters, "run-42") + } + } +} + +func TestCleanup_SweepIGWDetachesEveryAttachment(t *testing.T) { + // If an orphan IGW is attached to multiple VPCs (rare but possible + // after a botched run), every attachment must be detached before the + // IGW can be deleted. + c := &fakeEC2{ + sweepResources: sweepFixture{ + IGWs: []igwFixture{ + {ID: "igw-multi", VPCs: []string{"vpc-a", "vpc-b"}}, + }, + }, + } + w := &fakeWaiter{} + cfg := CleanupConfig{Region: "us-west-2", RunID: "run-42", SkipDirect: true} + + if err := Cleanup(context.Background(), newTestLogger(), c, w, cfg); err != nil { + t.Fatalf("Cleanup: %v", err) + } + + detaches := 0 + for _, call := range c.calls { + if call.Method == "DetachInternetGateway" { + detaches++ + } + } + if detaches != 2 { + t.Errorf("expected 2 DetachInternetGateway calls (one per attached VPC), got %d", detaches) + } +} + +func TestCleanup_SkipDirect(t *testing.T) { + // With -skip-direct, supplied IDs must be ignored. The sweep still + // runs but doesn't touch them (sweep only acts on tag-discovered + // resources, and we stage no sweep resources here). + // + // We assert this by inspecting every Delete*/Terminate* call and + // confirming none were issued against the supplied direct-path IDs. + c := &fakeEC2{} + w := &fakeWaiter{} + cfg := CleanupConfig{ + Region: "us-west-2", + RunID: "run-42", + VPCID: "vpc-direct", + IGWID: "igw-direct", + SubnetID: "subnet-direct", + RouteTableID: "rtb-direct", + RouteAssocID: "rtbassoc-direct", + SecurityGroupID: "sg-direct", + InstanceIDs: []string{"i-direct-1", "i-direct-2"}, + SkipDirect: true, + } + + if err := Cleanup(context.Background(), newTestLogger(), c, w, cfg); err != nil { + t.Fatalf("Cleanup: %v", err) + } + + for _, call := range c.calls { + switch in := call.Input.(type) { + case *ec2.TerminateInstancesInput: + for _, id := range in.InstanceIds { + if strings.HasPrefix(id, "i-direct") { + t.Errorf("TerminateInstances(%s) ran despite SkipDirect=true (direct path leaked)", id) + } + } + case *ec2.DeleteSecurityGroupInput: + if aws.ToString(in.GroupId) == "sg-direct" { + t.Errorf("DeleteSecurityGroup(sg-direct) ran despite SkipDirect=true") + } + case *ec2.DeleteSubnetInput: + if aws.ToString(in.SubnetId) == "subnet-direct" { + t.Errorf("DeleteSubnet(subnet-direct) ran despite SkipDirect=true") + } + case *ec2.DeleteRouteTableInput: + if aws.ToString(in.RouteTableId) == "rtb-direct" { + t.Errorf("DeleteRouteTable(rtb-direct) ran despite SkipDirect=true") + } + case *ec2.DisassociateRouteTableInput: + if aws.ToString(in.AssociationId) == "rtbassoc-direct" { + t.Errorf("DisassociateRouteTable(rtbassoc-direct) ran despite SkipDirect=true") + } + case *ec2.DeleteInternetGatewayInput: + if aws.ToString(in.InternetGatewayId) == "igw-direct" { + t.Errorf("DeleteInternetGateway(igw-direct) ran despite SkipDirect=true") + } + case *ec2.DetachInternetGatewayInput: + if aws.ToString(in.InternetGatewayId) == "igw-direct" { + t.Errorf("DetachInternetGateway(igw-direct) ran despite SkipDirect=true") + } + case *ec2.DeleteVpcInput: + if aws.ToString(in.VpcId) == "vpc-direct" { + t.Errorf("DeleteVpc(vpc-direct) ran despite SkipDirect=true") + } + } + } +} + +func TestCleanup_SweepErrorPropagationAtEachStage(t *testing.T) { + // sweepByTag has six Describe* call sites, any of which can fail + // (auth expiring, throttling, etc.). The default behavior is to log + // + continue (matches Bash `set +e`); -strict-sweep reverses that. + // This table covers each stage × strict mode so the contract is + // pinned end-to-end, not just for DescribeInstances. + stages := []string{ + "DescribeInstances", + "DescribeSecurityGroups", + "DescribeRouteTables", + "DescribeSubnets", + "DescribeInternetGateways", + "DescribeVpcs", + } + for _, stage := range stages { + for _, strict := range []bool{false, true} { + stage, strict := stage, strict + name := stage + if strict { + name += "/strict" + } else { + name += "/default" + } + t.Run(name, func(t *testing.T) { + c := &fakeEC2{failOn: map[string]error{stage: errStaged}} + w := &fakeWaiter{} + cfg := CleanupConfig{ + Region: "us-west-2", + RunID: "run-42", + SkipDirect: true, + StrictSweep: strict, + } + err := Cleanup(context.Background(), newTestLogger(), c, w, cfg) + if strict && err == nil { + t.Errorf("strict-sweep on with %s failure: expected error, got nil", stage) + } + if !strict && err != nil { + t.Errorf("strict-sweep off with %s failure: expected nil, got %v", stage, err) + } + }) + } + } +} + +func TestCleanup_IdempotentOnEmpty(t *testing.T) { + // Cleanup with no IDs and no orphans should succeed without error. + c := &fakeEC2{} + w := &fakeWaiter{} + cfg := CleanupConfig{Region: "us-west-2", RunID: "run-42"} + + if err := Cleanup(context.Background(), newTestLogger(), c, w, cfg); err != nil { + t.Fatalf("Cleanup on empty state errored: %v", err) + } +} + +func TestCleanup_PartialIDsDoNotPanicOnMissingFields(t *testing.T) { + // Some workflow runs fail mid-provision and only export half the + // IDs. Cleanup must tolerate any subset. + c := &fakeEC2{} + w := &fakeWaiter{} + cfg := CleanupConfig{ + Region: "us-west-2", + RunID: "run-42", + VPCID: "vpc-1", // present + SubnetID: "subnet-1", // present + // IGWID intentionally empty + // RouteTableID intentionally empty + SkipSweep: true, + } + + if err := Cleanup(context.Background(), newTestLogger(), c, w, cfg); err != nil { + t.Fatalf("Cleanup with partial IDs errored: %v", err) + } + + // Should call DeleteSubnet and DeleteVpc but not DeleteInternetGateway + for _, call := range c.calls { + if call.Method == "DeleteInternetGateway" { + t.Errorf("DeleteInternetGateway called despite empty IGWID") + } + if call.Method == "DeleteRouteTable" { + t.Errorf("DeleteRouteTable called despite empty RouteTableID") + } + } +} + +// helpers ───────────────────────────────────────────────────────────────── + +func assertHasTagFilter(t *testing.T, label string, filters []types.Filter, want string) { + t.Helper() + for _, f := range filters { + if aws.ToString(f.Name) != "tag:RunID" { + continue + } + for _, v := range f.Values { + if v == want { + return + } + } + } + t.Errorf("%s: filter list missing tag:RunID = %s", label, want) +} + +// silence the unused-import warning from ec2 that the test cases use only via interface +var _ = ec2.NewFromConfig diff --git a/.github/actions/aws-test-infra/src/go.mod b/.github/actions/aws-test-infra/src/go.mod new file mode 100644 index 0000000..af8e998 --- /dev/null +++ b/.github/actions/aws-test-infra/src/go.mod @@ -0,0 +1,25 @@ +module github.com/loft-sh/github-actions/aws-test-infra + +go 1.24 + +require ( + github.com/aws/aws-sdk-go-v2 v1.41.7 + github.com/aws/aws-sdk-go-v2/config v1.32.17 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.300.0 + github.com/aws/aws-sdk-go-v2/service/ssm v1.68.6 +) + +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect + github.com/aws/smithy-go v1.25.1 // indirect +) diff --git a/.github/actions/aws-test-infra/src/go.sum b/.github/actions/aws-test-infra/src/go.sum new file mode 100644 index 0000000..f46fed5 --- /dev/null +++ b/.github/actions/aws-test-infra/src/go.sum @@ -0,0 +1,32 @@ +github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8= +github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc= +github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU= +github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU= +github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.300.0 h1:HgOfUy9Sm2Q9UQAyj9I/7NZhIaymTEakGA/FnLw65lw= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.300.0/go.mod h1:Y95W0Hm6FYLPa6o0hbnJ+sWgmdc4ifcLFjGkdobWVhY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI= +github.com/aws/aws-sdk-go-v2/service/ssm v1.68.6 h1:0LPJjbSNEDHidGOXa0LfvSVbdn9/GdlJUQTgE0kFpso= +github.com/aws/aws-sdk-go-v2/service/ssm v1.68.6/go.mod h1:SrZAopBP5/lyQ6NBVXKlRp8wPIXhzBCZU98sEozmv8Y= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg= +github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk= +github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio= +github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI= +github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= diff --git a/.github/actions/aws-test-infra/src/ingress.go b/.github/actions/aws-test-infra/src/ingress.go new file mode 100644 index 0000000..7e6e77f --- /dev/null +++ b/.github/actions/aws-test-infra/src/ingress.go @@ -0,0 +1,85 @@ +package main + +import ( + "fmt" + "strconv" + "strings" +) + +// IngressRule describes a single security-group ingress permission. +// +// Encoded as a colon-delimited string for CLI use: +// +// ::: +// +// Examples (lifted from the existing workflows): +// +// -1:-1:-1:10.0.0.0/16 (intra-VPC, all protocols) +// tcp:8443:8443:0.0.0.0/0 (vCluster API, wide-open) +// tcp:30000:32767:1.2.3.4/32 (NodePort range, runner-only) +// icmp:-1:-1:10.0.0.0/16 (ICMP intra-VPC) +// +// Protocol "-1" means "all protocols". When protocol is "-1" or "icmp", AWS +// requires fromPort/toPort to be -1 (the workflow Bash sets them to -1 in +// these cases too). +type IngressRule struct { + Protocol string + FromPort int32 + ToPort int32 + CIDR string +} + +func parseIngressRule(s string) (IngressRule, error) { + parts := strings.SplitN(s, ":", 4) + if len(parts) != 4 { + return IngressRule{}, fmt.Errorf("ingress rule %q: expected protocol:fromPort:toPort:cidr", s) + } + + from, err := strconv.ParseInt(parts[1], 10, 32) + if err != nil { + return IngressRule{}, fmt.Errorf("ingress rule %q: parse fromPort: %w", s, err) + } + to, err := strconv.ParseInt(parts[2], 10, 32) + if err != nil { + return IngressRule{}, fmt.Errorf("ingress rule %q: parse toPort: %w", s, err) + } + + if parts[0] == "" { + return IngressRule{}, fmt.Errorf("ingress rule %q: protocol is empty", s) + } + if parts[3] == "" { + return IngressRule{}, fmt.Errorf("ingress rule %q: cidr is empty", s) + } + + return IngressRule{ + Protocol: parts[0], + FromPort: int32(from), + ToPort: int32(to), + CIDR: parts[3], + }, nil +} + +// ingressFlag implements flag.Value, allowing -ingress to be repeated. +type ingressFlag struct { + rules *[]IngressRule +} + +func (f *ingressFlag) String() string { + if f.rules == nil { + return "" + } + parts := make([]string, 0, len(*f.rules)) + for _, r := range *f.rules { + parts = append(parts, fmt.Sprintf("%s:%d:%d:%s", r.Protocol, r.FromPort, r.ToPort, r.CIDR)) + } + return strings.Join(parts, ",") +} + +func (f *ingressFlag) Set(value string) error { + rule, err := parseIngressRule(value) + if err != nil { + return err + } + *f.rules = append(*f.rules, rule) + return nil +} diff --git a/.github/actions/aws-test-infra/src/ingress_test.go b/.github/actions/aws-test-infra/src/ingress_test.go new file mode 100644 index 0000000..66188ec --- /dev/null +++ b/.github/actions/aws-test-infra/src/ingress_test.go @@ -0,0 +1,75 @@ +package main + +import "testing" + +func TestParseIngressRule(t *testing.T) { + tests := []struct { + name string + input string + want IngressRule + wantErr bool + }{ + { + name: "intra-vpc all protocols", + input: "-1:-1:-1:10.0.0.0/16", + want: IngressRule{Protocol: "-1", FromPort: -1, ToPort: -1, CIDR: "10.0.0.0/16"}, + }, + { + name: "tcp wide-open vCluster API", + input: "tcp:8443:8443:0.0.0.0/0", + want: IngressRule{Protocol: "tcp", FromPort: 8443, ToPort: 8443, CIDR: "0.0.0.0/0"}, + }, + { + name: "tcp NodePort range scoped to runner", + input: "tcp:30000:32767:1.2.3.4/32", + want: IngressRule{Protocol: "tcp", FromPort: 30000, ToPort: 32767, CIDR: "1.2.3.4/32"}, + }, + { + name: "icmp intra-vpc", + input: "icmp:-1:-1:10.0.0.0/16", + want: IngressRule{Protocol: "icmp", FromPort: -1, ToPort: -1, CIDR: "10.0.0.0/16"}, + }, + {name: "missing fields", input: "tcp:8443:8443", wantErr: true}, + {name: "empty protocol", input: ":8443:8443:0.0.0.0/0", wantErr: true}, + {name: "empty cidr", input: "tcp:8443:8443:", wantErr: true}, + {name: "non-numeric port", input: "tcp:abc:8443:0.0.0.0/0", wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := parseIngressRule(tt.input) + if (err != nil) != tt.wantErr { + t.Fatalf("parseIngressRule(%q) err=%v wantErr=%v", tt.input, err, tt.wantErr) + } + if tt.wantErr { + return + } + if got != tt.want { + t.Errorf("parseIngressRule(%q) = %+v, want %+v", tt.input, got, tt.want) + } + }) + } +} + +func TestIngressFlagAccumulates(t *testing.T) { + var rules []IngressRule + f := ingressFlag{rules: &rules} + + for _, in := range []string{ + "-1:-1:-1:10.0.0.0/16", + "tcp:8443:8443:0.0.0.0/0", + "icmp:-1:-1:10.0.0.0/16", + } { + if err := f.Set(in); err != nil { + t.Fatalf("Set(%q): %v", in, err) + } + } + + if len(rules) != 3 { + t.Fatalf("got %d rules, want 3", len(rules)) + } + if rules[1].Protocol != "tcp" || rules[1].FromPort != 8443 { + t.Errorf("rule[1] mis-parsed: %+v", rules[1]) + } +} diff --git a/.github/actions/aws-test-infra/src/main.go b/.github/actions/aws-test-infra/src/main.go new file mode 100644 index 0000000..e873797 --- /dev/null +++ b/.github/actions/aws-test-infra/src/main.go @@ -0,0 +1,68 @@ +// Command aws-test-infra provisions and tears down AWS test infrastructure +// (VPC + subnet + IGW + route table + security group + EC2 instances) for +// use by GitHub Actions e2e workflows. It replaces hundreds of lines of +// duplicated Bash + aws-cli that previously lived inline in two +// vcluster-pro workflows. +// +// Two subcommands: +// +// aws-test-infra provision [flags] +// aws-test-infra cleanup [flags] +// +// Both rely on the default aws-sdk-go-v2 credential chain. Workflows that +// already use aws-actions/configure-aws-credentials (OIDC + assume-role) +// pass credentials in via env vars with no extra wiring. +package main + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "os/signal" + "syscall" +) + +func main() { + if err := run(context.Background(), os.Stderr, os.Args); err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err) + os.Exit(1) + } +} + +func run(ctx context.Context, stderr io.Writer, args []string) error { + if len(args) < 2 { + printUsage(stderr) + return errors.New("subcommand required") + } + + logger := slog.New(slog.NewTextHandler(stderr, &slog.HandlerOptions{Level: slog.LevelInfo})) + + ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) + defer stop() + + switch args[1] { + case "provision": + return runProvision(ctx, logger, args[0]+" provision", args[2:]) + case "cleanup": + return runCleanup(ctx, logger, args[0]+" cleanup", args[2:]) + case "-h", "--help", "help": + printUsage(stderr) + return nil + default: + printUsage(stderr) + return fmt.Errorf("unknown subcommand: %s", args[1]) + } +} + +func printUsage(w io.Writer) { + fmt.Fprintln(w, `Usage: aws-test-infra [flags] + +Subcommands: + provision Create VPC, subnet, IGW, route table, security group, and EC2 instances + cleanup Tear down resources by ID and run a tag-based fallback sweep + +Run "aws-test-infra -h" for subcommand flags.`) +} diff --git a/.github/actions/aws-test-infra/src/mocks_test.go b/.github/actions/aws-test-infra/src/mocks_test.go new file mode 100644 index 0000000..592d863 --- /dev/null +++ b/.github/actions/aws-test-infra/src/mocks_test.go @@ -0,0 +1,438 @@ +package main + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" +) + +// fakeEC2 is a hand-rolled mock satisfying EC2API. Every method records +// itself in calls (so tests can assert ordering) and returns happy-path +// canned responses unless an error is staged via failOn. +type fakeEC2 struct { + mu sync.Mutex + calls []apiCall + + failOn map[string]error // method name → error to return + + // When sweepResources is set, the Describe* methods used by the sweep + // return resources tagged with the matching RunID. Used by cleanup + // tests to feed orphaned resources into the sweep path. + sweepResources sweepFixture +} + +type apiCall struct { + Method string + Input interface{} + TagSpec []types.TagSpecification +} + +type sweepFixture struct { + Instances []string + SGs []string + RouteTables []routeTableFixture + Subnets []string + IGWs []igwFixture + VPCs []string +} + +type routeTableFixture struct { + ID string + AssociationIDs []string +} + +type igwFixture struct { + ID string + VPCs []string +} + +func (f *fakeEC2) record(method string, input interface{}, ts []types.TagSpecification) { + f.mu.Lock() + defer f.mu.Unlock() + f.calls = append(f.calls, apiCall{Method: method, Input: input, TagSpec: ts}) +} + +func (f *fakeEC2) shouldFail(method string) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.failOn == nil { + return nil + } + return f.failOn[method] +} + +// methods executed in order ───────────────────────────────────────────── + +func (f *fakeEC2) CreateVpc(_ context.Context, in *ec2.CreateVpcInput, _ ...func(*ec2.Options)) (*ec2.CreateVpcOutput, error) { + f.record("CreateVpc", in, in.TagSpecifications) + if err := f.shouldFail("CreateVpc"); err != nil { + return nil, err + } + return &ec2.CreateVpcOutput{Vpc: &types.Vpc{VpcId: aws.String("vpc-mock")}}, nil +} +func (f *fakeEC2) ModifyVpcAttribute(_ context.Context, in *ec2.ModifyVpcAttributeInput, _ ...func(*ec2.Options)) (*ec2.ModifyVpcAttributeOutput, error) { + f.record("ModifyVpcAttribute", in, nil) + return &ec2.ModifyVpcAttributeOutput{}, f.shouldFail("ModifyVpcAttribute") +} +func (f *fakeEC2) DeleteVpc(_ context.Context, in *ec2.DeleteVpcInput, _ ...func(*ec2.Options)) (*ec2.DeleteVpcOutput, error) { + f.record("DeleteVpc", in, nil) + return &ec2.DeleteVpcOutput{}, f.shouldFail("DeleteVpc") +} +func (f *fakeEC2) DescribeVpcs(_ context.Context, in *ec2.DescribeVpcsInput, _ ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) { + f.record("DescribeVpcs", in, nil) + if err := f.shouldFail("DescribeVpcs"); err != nil { + return nil, err + } + out := &ec2.DescribeVpcsOutput{} + for _, id := range f.sweepResources.VPCs { + id := id + out.Vpcs = append(out.Vpcs, types.Vpc{VpcId: aws.String(id)}) + } + return out, nil +} + +func (f *fakeEC2) CreateInternetGateway(_ context.Context, in *ec2.CreateInternetGatewayInput, _ ...func(*ec2.Options)) (*ec2.CreateInternetGatewayOutput, error) { + f.record("CreateInternetGateway", in, in.TagSpecifications) + return &ec2.CreateInternetGatewayOutput{InternetGateway: &types.InternetGateway{InternetGatewayId: aws.String("igw-mock")}}, f.shouldFail("CreateInternetGateway") +} +func (f *fakeEC2) AttachInternetGateway(_ context.Context, in *ec2.AttachInternetGatewayInput, _ ...func(*ec2.Options)) (*ec2.AttachInternetGatewayOutput, error) { + f.record("AttachInternetGateway", in, nil) + return &ec2.AttachInternetGatewayOutput{}, f.shouldFail("AttachInternetGateway") +} +func (f *fakeEC2) DetachInternetGateway(_ context.Context, in *ec2.DetachInternetGatewayInput, _ ...func(*ec2.Options)) (*ec2.DetachInternetGatewayOutput, error) { + f.record("DetachInternetGateway", in, nil) + return &ec2.DetachInternetGatewayOutput{}, f.shouldFail("DetachInternetGateway") +} +func (f *fakeEC2) DeleteInternetGateway(_ context.Context, in *ec2.DeleteInternetGatewayInput, _ ...func(*ec2.Options)) (*ec2.DeleteInternetGatewayOutput, error) { + f.record("DeleteInternetGateway", in, nil) + return &ec2.DeleteInternetGatewayOutput{}, f.shouldFail("DeleteInternetGateway") +} +func (f *fakeEC2) DescribeInternetGateways(_ context.Context, in *ec2.DescribeInternetGatewaysInput, _ ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) { + f.record("DescribeInternetGateways", in, nil) + if err := f.shouldFail("DescribeInternetGateways"); err != nil { + return nil, err + } + out := &ec2.DescribeInternetGatewaysOutput{} + for _, ig := range f.sweepResources.IGWs { + ig := ig + atts := make([]types.InternetGatewayAttachment, 0, len(ig.VPCs)) + for _, v := range ig.VPCs { + v := v + atts = append(atts, types.InternetGatewayAttachment{VpcId: aws.String(v)}) + } + out.InternetGateways = append(out.InternetGateways, types.InternetGateway{ + InternetGatewayId: aws.String(ig.ID), + Attachments: atts, + }) + } + return out, nil +} + +func (f *fakeEC2) DescribeAvailabilityZones(_ context.Context, in *ec2.DescribeAvailabilityZonesInput, _ ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) { + f.record("DescribeAvailabilityZones", in, nil) + return &ec2.DescribeAvailabilityZonesOutput{ + AvailabilityZones: []types.AvailabilityZone{{ZoneName: aws.String("us-west-2a")}}, + }, nil +} +func (f *fakeEC2) CreateSubnet(_ context.Context, in *ec2.CreateSubnetInput, _ ...func(*ec2.Options)) (*ec2.CreateSubnetOutput, error) { + f.record("CreateSubnet", in, in.TagSpecifications) + return &ec2.CreateSubnetOutput{Subnet: &types.Subnet{SubnetId: aws.String("subnet-mock")}}, f.shouldFail("CreateSubnet") +} +func (f *fakeEC2) ModifySubnetAttribute(_ context.Context, in *ec2.ModifySubnetAttributeInput, _ ...func(*ec2.Options)) (*ec2.ModifySubnetAttributeOutput, error) { + f.record("ModifySubnetAttribute", in, nil) + return &ec2.ModifySubnetAttributeOutput{}, f.shouldFail("ModifySubnetAttribute") +} +func (f *fakeEC2) DeleteSubnet(_ context.Context, in *ec2.DeleteSubnetInput, _ ...func(*ec2.Options)) (*ec2.DeleteSubnetOutput, error) { + f.record("DeleteSubnet", in, nil) + return &ec2.DeleteSubnetOutput{}, f.shouldFail("DeleteSubnet") +} +func (f *fakeEC2) DescribeSubnets(_ context.Context, in *ec2.DescribeSubnetsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) { + f.record("DescribeSubnets", in, nil) + if err := f.shouldFail("DescribeSubnets"); err != nil { + return nil, err + } + out := &ec2.DescribeSubnetsOutput{} + for _, id := range f.sweepResources.Subnets { + id := id + out.Subnets = append(out.Subnets, types.Subnet{SubnetId: aws.String(id)}) + } + return out, nil +} + +func (f *fakeEC2) CreateRouteTable(_ context.Context, in *ec2.CreateRouteTableInput, _ ...func(*ec2.Options)) (*ec2.CreateRouteTableOutput, error) { + f.record("CreateRouteTable", in, in.TagSpecifications) + return &ec2.CreateRouteTableOutput{RouteTable: &types.RouteTable{RouteTableId: aws.String("rtb-mock")}}, f.shouldFail("CreateRouteTable") +} +func (f *fakeEC2) CreateRoute(_ context.Context, in *ec2.CreateRouteInput, _ ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) { + f.record("CreateRoute", in, nil) + return &ec2.CreateRouteOutput{Return: aws.Bool(true)}, f.shouldFail("CreateRoute") +} +func (f *fakeEC2) AssociateRouteTable(_ context.Context, in *ec2.AssociateRouteTableInput, _ ...func(*ec2.Options)) (*ec2.AssociateRouteTableOutput, error) { + f.record("AssociateRouteTable", in, nil) + return &ec2.AssociateRouteTableOutput{AssociationId: aws.String("rtbassoc-mock")}, f.shouldFail("AssociateRouteTable") +} +func (f *fakeEC2) DisassociateRouteTable(_ context.Context, in *ec2.DisassociateRouteTableInput, _ ...func(*ec2.Options)) (*ec2.DisassociateRouteTableOutput, error) { + f.record("DisassociateRouteTable", in, nil) + return &ec2.DisassociateRouteTableOutput{}, f.shouldFail("DisassociateRouteTable") +} +func (f *fakeEC2) DeleteRouteTable(_ context.Context, in *ec2.DeleteRouteTableInput, _ ...func(*ec2.Options)) (*ec2.DeleteRouteTableOutput, error) { + f.record("DeleteRouteTable", in, nil) + return &ec2.DeleteRouteTableOutput{}, f.shouldFail("DeleteRouteTable") +} +func (f *fakeEC2) DescribeRouteTables(_ context.Context, in *ec2.DescribeRouteTablesInput, _ ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) { + f.record("DescribeRouteTables", in, nil) + if err := f.shouldFail("DescribeRouteTables"); err != nil { + return nil, err + } + out := &ec2.DescribeRouteTablesOutput{} + for _, rt := range f.sweepResources.RouteTables { + rt := rt + assocs := make([]types.RouteTableAssociation, 0, len(rt.AssociationIDs)) + for _, aid := range rt.AssociationIDs { + aid := aid + assocs = append(assocs, types.RouteTableAssociation{RouteTableAssociationId: aws.String(aid)}) + } + out.RouteTables = append(out.RouteTables, types.RouteTable{ + RouteTableId: aws.String(rt.ID), + Associations: assocs, + }) + } + return out, nil +} + +func (f *fakeEC2) CreateSecurityGroup(_ context.Context, in *ec2.CreateSecurityGroupInput, _ ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) { + f.record("CreateSecurityGroup", in, in.TagSpecifications) + return &ec2.CreateSecurityGroupOutput{GroupId: aws.String("sg-mock")}, f.shouldFail("CreateSecurityGroup") +} +func (f *fakeEC2) AuthorizeSecurityGroupIngress(_ context.Context, in *ec2.AuthorizeSecurityGroupIngressInput, _ ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { + f.record("AuthorizeSecurityGroupIngress", in, nil) + return &ec2.AuthorizeSecurityGroupIngressOutput{}, f.shouldFail("AuthorizeSecurityGroupIngress") +} +func (f *fakeEC2) DeleteSecurityGroup(_ context.Context, in *ec2.DeleteSecurityGroupInput, _ ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) { + f.record("DeleteSecurityGroup", in, nil) + return &ec2.DeleteSecurityGroupOutput{}, f.shouldFail("DeleteSecurityGroup") +} +func (f *fakeEC2) DescribeSecurityGroups(_ context.Context, in *ec2.DescribeSecurityGroupsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + f.record("DescribeSecurityGroups", in, nil) + if err := f.shouldFail("DescribeSecurityGroups"); err != nil { + return nil, err + } + out := &ec2.DescribeSecurityGroupsOutput{} + for _, id := range f.sweepResources.SGs { + id := id + out.SecurityGroups = append(out.SecurityGroups, types.SecurityGroup{GroupId: aws.String(id)}) + } + return out, nil +} + +func (f *fakeEC2) DescribeImages(_ context.Context, in *ec2.DescribeImagesInput, _ ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error) { + f.record("DescribeImages", in, nil) + if err := f.shouldFail("DescribeImages"); err != nil { + return nil, err + } + return &ec2.DescribeImagesOutput{ + Images: []types.Image{ + {ImageId: aws.String("ami-old"), CreationDate: aws.String("2024-01-01T00:00:00.000Z")}, + {ImageId: aws.String("ami-newest"), CreationDate: aws.String("2026-01-01T00:00:00.000Z")}, + {ImageId: aws.String("ami-mid"), CreationDate: aws.String("2025-06-01T00:00:00.000Z")}, + }, + }, nil +} +func (f *fakeEC2) RunInstances(_ context.Context, in *ec2.RunInstancesInput, _ ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) { + f.record("RunInstances", in, in.TagSpecifications) + if err := f.shouldFail("RunInstances"); err != nil { + return nil, err + } + // Synthesize a unique instance ID from the role tag so tests can map + // instance → role even though our fake is stateless. + role := "unknown" + for _, ts := range in.TagSpecifications { + if ts.ResourceType != types.ResourceTypeInstance { + continue + } + for _, t := range ts.Tags { + if aws.ToString(t.Key) == "Role" { + role = aws.ToString(t.Value) + } + } + } + return &ec2.RunInstancesOutput{ + Instances: []types.Instance{{InstanceId: aws.String("i-" + role)}}, + }, nil +} +func (f *fakeEC2) TerminateInstances(_ context.Context, in *ec2.TerminateInstancesInput, _ ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) { + f.record("TerminateInstances", in, nil) + return &ec2.TerminateInstancesOutput{}, f.shouldFail("TerminateInstances") +} +func (f *fakeEC2) DescribeInstances(_ context.Context, in *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + f.record("DescribeInstances", in, nil) + if err := f.shouldFail("DescribeInstances"); err != nil { + return nil, err + } + // If the test staged sweep instances, return them grouped under one + // reservation. Otherwise return the input IDs with a stub public IP + // (covers the "describe primary" call after RunInstances). + out := &ec2.DescribeInstancesOutput{} + if len(f.sweepResources.Instances) > 0 { + // Only return sweep instances if filtered by tag — i.e. when the + // caller passes Filters (not InstanceIds). + if len(in.Filters) > 0 { + res := types.Reservation{} + for _, id := range f.sweepResources.Instances { + id := id + res.Instances = append(res.Instances, types.Instance{InstanceId: aws.String(id)}) + } + out.Reservations = []types.Reservation{res} + return out, nil + } + } + res := types.Reservation{} + for _, id := range in.InstanceIds { + id := id + res.Instances = append(res.Instances, types.Instance{ + InstanceId: aws.String(id), + PublicIpAddress: aws.String("203.0.113.1"), + }) + } + out.Reservations = []types.Reservation{res} + return out, nil +} + +// fakeSSM ───────────────────────────────────────────────────────────────── + +type fakeSSM struct { + mu sync.Mutex + calls int + online int // how many to report online from the next call onward +} + +func (s *fakeSSM) DescribeInstanceInformation(_ context.Context, _ *ssm.DescribeInstanceInformationInput, _ ...func(*ssm.Options)) (*ssm.DescribeInstanceInformationOutput, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.calls++ + out := &ssm.DescribeInstanceInformationOutput{} + for i := 0; i < s.online; i++ { + out.InstanceInformationList = append(out.InstanceInformationList, ssmtypes.InstanceInformation{ + PingStatus: ssmtypes.PingStatusOnline, + }) + } + return out, nil +} + +// fakeWaiter ───────────────────────────────────────────────────────────── + +type fakeWaiter struct { + runErr error + terminatedErr error + runCalls [][]string + termCalls [][]string +} + +func (w *fakeWaiter) WaitInstanceRunning(_ context.Context, ids []string) error { + w.runCalls = append(w.runCalls, append([]string(nil), ids...)) + return w.runErr +} + +func (w *fakeWaiter) WaitInstanceTerminated(_ context.Context, ids []string) error { + w.termCalls = append(w.termCalls, append([]string(nil), ids...)) + return w.terminatedErr +} + +// helpers ──────────────────────────────────────────────────────────────── + +func methodSequence(calls []apiCall) []string { + out := make([]string, 0, len(calls)) + for _, c := range calls { + out = append(out, c.Method) + } + return out +} + +// requireOrdering asserts that `expected` appears as a subsequence of +// methodSequence — i.e. each entry occurs after the prior in the recorded +// call list. Doesn't reject extra interleaved calls. +// +// Use this for the loose "these calls must happen, roughly in this order" +// shape. For dependency-critical pairs (terminate→wait, disassoc→delete-RT, +// detach→delete-IGW, instance-termination must precede any subnet/SG/VPC +// delete), use requireImmediatelyAfter instead — it catches insertion-of- +// wrong-call-between bugs that requireOrdering misses. +func requireOrdering(actual, expected []string) error { + idx := 0 + for _, m := range actual { + if idx < len(expected) && m == expected[idx] { + idx++ + } + } + if idx == len(expected) { + return nil + } + return fmt.Errorf("expected ordering %v not found in %v (got %d/%d)", expected, actual, idx, len(expected)) +} + +// requireImmediatelyAfter asserts that every occurrence of method `before` +// is immediately followed by method `after` in the recorded call list, +// with no other calls between them. +// +// Use for pairs where any intervening call would be a real bug: +// - TerminateInstances → WaitInstanceTerminated (waiting after the call) +// - DisassociateRouteTable → DeleteRouteTable (RT can't be deleted while associated) +// - DetachInternetGateway → DeleteInternetGateway (IGW can't be deleted while attached) +func requireImmediatelyAfter(actual []string, before, after string) error { + for i, m := range actual { + if m != before { + continue + } + if i+1 >= len(actual) { + return fmt.Errorf("%s at index %d has no following call; expected %s", before, i, after) + } + if actual[i+1] != after { + return fmt.Errorf("%s at index %d followed by %s, expected %s (full sequence: %v)", before, i, actual[i+1], after, actual) + } + } + return nil +} + +// requireBefore asserts that every occurrence of `early` happens before +// every occurrence of `late` in the recorded call list. Allows interleaved +// other calls — just enforces a partial order. +// +// Use for "X must finish before any Y starts" — e.g. +// WaitInstanceTerminated must precede every DeleteSecurityGroup / +// DeleteSubnet / DeleteVpc, or those deletes will fail with InUse errors +// in production. +// +// If `late` appears in the sequence but `early` never does, that's a +// failure: the dependency couldn't possibly have been satisfied. (The +// previous version silently passed in this case — a footgun in tests +// where the `early` step might be absent due to some other bug.) +func requireBefore(actual []string, early, late string) error { + lastEarly := -1 + lateIndices := []int{} + for i, m := range actual { + if m == early { + lastEarly = i + } + if m == late { + lateIndices = append(lateIndices, i) + } + } + if len(lateIndices) > 0 && lastEarly == -1 { + return fmt.Errorf("%s appears in sequence at index %d but %s never does (full sequence: %v)", late, lateIndices[0], early, actual) + } + for _, i := range lateIndices { + if i < lastEarly { + return fmt.Errorf("%s at index %d happens before final %s at index %d (full sequence: %v)", late, i, early, lastEarly, actual) + } + } + return nil +} + +var errStaged = errors.New("staged failure") diff --git a/.github/actions/aws-test-infra/src/output.go b/.github/actions/aws-test-infra/src/output.go new file mode 100644 index 0000000..ae3cbe6 --- /dev/null +++ b/.github/actions/aws-test-infra/src/output.go @@ -0,0 +1,118 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "log/slog" + "os" + "strings" +) + +// emitOutput writes ResourceIDs to the destination chosen by the caller. +// +// destination semantics: +// +// - "" (empty) → JSON to stdout +// - $GITHUB_OUTPUT path → key=value lines (consumable by `${{ steps.x.outputs.* }}`) +// - $GITHUB_ENV path → key=value lines (visible to subsequent steps as env vars) +// - any other path → file in the chosen format +// +// format chooses the encoding when the destination doesn't make it obvious: +// +// - "auto" → infer from destination path +// - "github-output" / "github-env" → key=value lines (identical encoding; +// the only difference is which file the action runner reads them from) +// - "json" → pretty-printed JSON +func emitOutput(logger *slog.Logger, destination, format string, ids ResourceIDs) error { + if format == "" { + format = "auto" + } + if format == "auto" { + switch { + case destination == "": + format = "json" + case strings.HasSuffix(destination, "GITHUB_OUTPUT") || strings.Contains(destination, "/runner/file_commands/set_output"): + format = "github-output" + case strings.HasSuffix(destination, "GITHUB_ENV") || strings.Contains(destination, "/runner/file_commands/set_env"): + format = "github-env" + default: + format = "json" + } + } + + var w io.Writer + if destination == "" { + w = os.Stdout + } else { + // GITHUB_OUTPUT / GITHUB_ENV are append-mode files per Actions docs. + f, err := os.OpenFile(destination, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("open output destination %q: %w", destination, err) + } + defer f.Close() + w = f + } + + switch format { + case "json": + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + if err := enc.Encode(ids); err != nil { + return fmt.Errorf("encode json output: %w", err) + } + case "github-output", "github-env": + if err := writeKeyValuePairs(w, ids); err != nil { + return fmt.Errorf("write key-value output: %w", err) + } + default: + return fmt.Errorf("unknown output format: %s", format) + } + + logger.Info("emitted output", "destination", destinationLabel(destination), "format", format) + return nil +} + +func destinationLabel(d string) string { + if d == "" { + return "stdout" + } + return d +} + +func writeKeyValuePairs(w io.Writer, ids ResourceIDs) error { + pairs := []struct { + k, v string + }{ + {"vpc_id", ids.VPCID}, + {"igw_id", ids.IGWID}, + {"subnet_id", ids.SubnetID}, + {"route_table_id", ids.RouteTableID}, + {"route_assoc_id", ids.RouteAssocID}, + {"security_group_id", ids.SecurityGroupID}, + {"ami_id", ids.AMIID}, + {"primary_public_ip", ids.PrimaryPublicIP}, + {"instance_ids", strings.Join(ids.InstanceIDs, ",")}, + } + for _, p := range pairs { + if _, err := fmt.Fprintf(w, "%s=%s\n", p.k, p.v); err != nil { + return err + } + } + for role, id := range ids.InstanceIDByRole { + if _, err := fmt.Fprintf(w, "instance_id_%s=%s\n", role, id); err != nil { + return err + } + } + // JSON-map of role → instance ID. Consumers with arbitrary role names + // (anything other than primary/worker1/worker2) read this with + // `${{ fromJSON(steps.provision.outputs.instance-id-by-role). }}`. + roleJSON, err := json.Marshal(ids.InstanceIDByRole) + if err != nil { + return fmt.Errorf("encode instance-id-by-role: %w", err) + } + if _, err := fmt.Fprintf(w, "instance_id_by_role=%s\n", roleJSON); err != nil { + return err + } + return nil +} diff --git a/.github/actions/aws-test-infra/src/output_test.go b/.github/actions/aws-test-infra/src/output_test.go new file mode 100644 index 0000000..1423cf3 --- /dev/null +++ b/.github/actions/aws-test-infra/src/output_test.go @@ -0,0 +1,172 @@ +package main + +import ( + "encoding/json" + "io" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" +) + +func newTestLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func sampleIDs() ResourceIDs { + return ResourceIDs{ + VPCID: "vpc-abc", + IGWID: "igw-def", + SubnetID: "subnet-001", + RouteTableID: "rtb-002", + RouteAssocID: "rtbassoc-003", + SecurityGroupID: "sg-004", + AMIID: "ami-005", + InstanceIDs: []string{"i-1", "i-2", "i-3"}, + InstanceIDByRole: map[string]string{ + "primary": "i-1", + "worker1": "i-2", + "worker2": "i-3", + }, + PrimaryPublicIP: "1.2.3.4", + } +} + +func TestEmitOutput_GitHubOutput(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "GITHUB_OUTPUT") + + if err := emitOutput(newTestLogger(), path, "github-output", sampleIDs()); err != nil { + t.Fatalf("emitOutput: %v", err) + } + body, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + s := string(body) + for _, want := range []string{ + "vpc_id=vpc-abc", + "igw_id=igw-def", + "subnet_id=subnet-001", + "route_table_id=rtb-002", + "route_assoc_id=rtbassoc-003", + "security_group_id=sg-004", + "ami_id=ami-005", + "primary_public_ip=1.2.3.4", + "instance_ids=i-1,i-2,i-3", + "instance_id_primary=i-1", + "instance_id_worker1=i-2", + "instance_id_worker2=i-3", + } { + if !strings.Contains(s, want) { + t.Errorf("output missing %q\nfull output:\n%s", want, s) + } + } + + // instance_id_by_role must round-trip as valid JSON the action consumer + // can parse with fromJSON. Map ordering is non-deterministic; assert + // the three pairs by parsing. + for _, line := range strings.Split(s, "\n") { + const prefix = "instance_id_by_role=" + if !strings.HasPrefix(line, prefix) { + continue + } + var got map[string]string + if err := json.Unmarshal([]byte(line[len(prefix):]), &got); err != nil { + t.Fatalf("instance_id_by_role is not valid JSON: %v\n%s", err, line) + } + if got["primary"] != "i-1" || got["worker1"] != "i-2" || got["worker2"] != "i-3" { + t.Errorf("instance_id_by_role mis-parsed: %+v", got) + } + return + } + t.Errorf("instance_id_by_role line not found in output:\n%s", s) +} + +func TestEmitOutput_NonStandardRoles(t *testing.T) { + // Consumers with arbitrary role names (e.g. "primary,secondary") must + // be able to retrieve instance IDs via the JSON map output, since the + // hardcoded primary/worker1/worker2 outputs only cover the common case. + dir := t.TempDir() + path := filepath.Join(dir, "GITHUB_OUTPUT") + + ids := ResourceIDs{ + InstanceIDs: []string{"i-a", "i-b"}, + InstanceIDByRole: map[string]string{"primary": "i-a", "secondary": "i-b"}, + } + if err := emitOutput(newTestLogger(), path, "github-output", ids); err != nil { + t.Fatalf("emitOutput: %v", err) + } + body, _ := os.ReadFile(path) + s := string(body) + + for _, line := range strings.Split(s, "\n") { + const prefix = "instance_id_by_role=" + if !strings.HasPrefix(line, prefix) { + continue + } + var got map[string]string + if err := json.Unmarshal([]byte(line[len(prefix):]), &got); err != nil { + t.Fatalf("not valid JSON: %v", err) + } + if got["primary"] != "i-a" || got["secondary"] != "i-b" { + t.Errorf("non-standard roles not preserved: %+v", got) + } + return + } + t.Errorf("instance_id_by_role line missing for non-standard roles:\n%s", s) +} + +func TestEmitOutput_AutoFormatInfersFromPath(t *testing.T) { + dir := t.TempDir() + + envPath := filepath.Join(dir, "GITHUB_ENV") + if err := emitOutput(newTestLogger(), envPath, "auto", sampleIDs()); err != nil { + t.Fatalf("emitOutput: %v", err) + } + body, _ := os.ReadFile(envPath) + if !strings.Contains(string(body), "vpc_id=vpc-abc") { + t.Errorf("expected key=value format for GITHUB_ENV path, got:\n%s", body) + } + + // A plain path that doesn't match either marker should default to JSON. + jsonPath := filepath.Join(dir, "out.txt") + if err := emitOutput(newTestLogger(), jsonPath, "auto", sampleIDs()); err != nil { + t.Fatalf("emitOutput: %v", err) + } + body, _ = os.ReadFile(jsonPath) + var got ResourceIDs + if err := json.Unmarshal(body, &got); err != nil { + t.Errorf("expected JSON for unrecognized path, got: %s\nerr: %v", body, err) + } +} + +func TestEmitOutput_AppendsRatherThanOverwrites(t *testing.T) { + // GITHUB_OUTPUT and GITHUB_ENV are both append-mode files in real + // GitHub Actions. A run that calls emitOutput twice (e.g. provision + // success path + an additional metadata write) must accumulate both + // sets of pairs, not lose the first one. + dir := t.TempDir() + path := filepath.Join(dir, "GITHUB_OUTPUT") + + first := sampleIDs() + first.VPCID = "vpc-first" + if err := emitOutput(newTestLogger(), path, "github-output", first); err != nil { + t.Fatalf("first emit: %v", err) + } + second := sampleIDs() + second.VPCID = "vpc-second" + if err := emitOutput(newTestLogger(), path, "github-output", second); err != nil { + t.Fatalf("second emit: %v", err) + } + + body, _ := os.ReadFile(path) + s := string(body) + if !strings.Contains(s, "vpc_id=vpc-first") { + t.Errorf("expected first emission to be preserved, got:\n%s", s) + } + if !strings.Contains(s, "vpc_id=vpc-second") { + t.Errorf("expected second emission to be appended, got:\n%s", s) + } +} diff --git a/.github/actions/aws-test-infra/src/provision.go b/.github/actions/aws-test-infra/src/provision.go new file mode 100644 index 0000000..c28aee6 --- /dev/null +++ b/.github/actions/aws-test-infra/src/provision.go @@ -0,0 +1,522 @@ +package main + +import ( + "context" + "encoding/base64" + "errors" + "flag" + "fmt" + "log/slog" + "os" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" +) + +// defaultWaiterMaxWait caps how long the SDK's typed waiters +// (instance-running, instance-terminated) block. We set a ceiling to fail +// loudly instead of hanging. Instance-running is overridable per-call via +// the `-instance-running-timeout` flag for slow-boot edge cases. +const defaultWaiterMaxWait = 30 * time.Minute + +// ProvisionConfig is the parsed flag set for `provision`. +type ProvisionConfig struct { + Region string + RunID string + ConsumerTagKey string + ConsumerTagVal string + VPCCIDR string + SubnetCIDR string + AvailabilityZone string + + AMIID string + AMIOwner string + AMIFilter string + AMIArchitecture string + AMIVirtualizationType string + + SGName string + SGDescription string + IngressRules []IngressRule + + InstanceType string + InstanceProfile string + InstanceRoles []string + RootDevice string + VolumeSizeGB int32 + UserDataFile string + + SSMWaitTimeout time.Duration + SSMWaitInterval time.Duration + SkipSSMWait bool + InstanceRunningTimeout time.Duration + + OutputPath string + OutputFormat string +} + +// ResourceIDs is the result of provisioning. Every workflow consumer needs +// these IDs to drive subsequent steps and to clean up. +type ResourceIDs struct { + VPCID string `json:"vpc_id"` + IGWID string `json:"igw_id"` + SubnetID string `json:"subnet_id"` + RouteTableID string `json:"route_table_id"` + RouteAssocID string `json:"route_assoc_id"` + SecurityGroupID string `json:"security_group_id"` + AMIID string `json:"ami_id"` + InstanceIDs []string `json:"instance_ids"` + InstanceIDByRole map[string]string `json:"instance_id_by_role"` + PrimaryPublicIP string `json:"primary_public_ip"` +} + +func runProvision(ctx context.Context, logger *slog.Logger, name string, args []string) error { + fs := flag.NewFlagSet(name, flag.ContinueOnError) + cfg := ProvisionConfig{} + var ( + consumerTag string + instanceRoles string + ) + fs.StringVar(&cfg.Region, "region", "", "AWS region (required)") + fs.StringVar(&cfg.RunID, "run-id", "", "Unique run identifier; tagged on every resource as RunID (required)") + fs.StringVar(&consumerTag, "consumer-tag", "", "Consumer tag in KEY=VALUE form (e.g. SELinuxE2E=true) (required)") + fs.StringVar(&cfg.VPCCIDR, "vpc-cidr", "10.0.0.0/16", "VPC CIDR") + fs.StringVar(&cfg.SubnetCIDR, "subnet-cidr", "10.0.1.0/24", "Subnet CIDR") + fs.StringVar(&cfg.AvailabilityZone, "availability-zone", "", "AZ for the subnet; if empty, picks the first AZ in the region") + + fs.StringVar(&cfg.AMIID, "ami-id", "", "Use this exact AMI ID (skips AMI lookup)") + fs.StringVar(&cfg.AMIOwner, "ami-owner", "", "AMI owner (account ID or alias) for lookup") + fs.StringVar(&cfg.AMIFilter, "ami-filter", "", "AMI name filter for lookup (latest CreationDate wins)") + fs.StringVar(&cfg.AMIArchitecture, "ami-architecture", "", "Optional architecture filter for AMI lookup (e.g. x86_64, arm64). Empty means no filter.") + fs.StringVar(&cfg.AMIVirtualizationType, "ami-virtualization-type", "", "Optional virtualization-type filter for AMI lookup (e.g. hvm, paravirtual). Empty means no filter.") + + fs.StringVar(&cfg.SGName, "sg-name", "", "Security group name (required)") + fs.StringVar(&cfg.SGDescription, "sg-description", "", "Security group description") + rules := ingressFlag{rules: &cfg.IngressRules} + fs.Var(&rules, "ingress", "Ingress rule in protocol:fromPort:toPort:cidr form; repeatable") + + fs.StringVar(&cfg.InstanceType, "instance-type", "m5.xlarge", "EC2 instance type") + fs.StringVar(&cfg.InstanceProfile, "instance-profile", "", "IAM instance profile name") + fs.StringVar(&instanceRoles, "instance-roles", "primary,worker1,worker2", "Comma-separated role labels (one instance per role)") + fs.StringVar(&cfg.RootDevice, "root-device", "/dev/sda1", "Root block-device name (e.g. /dev/sda1 or /dev/xvda)") + var volumeSizeGB int + fs.IntVar(&volumeSizeGB, "volume-size-gb", 100, "Root volume size in GB") + fs.StringVar(&cfg.UserDataFile, "user-data-file", "", "Path to a file with raw user-data; the binary base64-encodes it before passing to RunInstances") + + fs.DurationVar(&cfg.SSMWaitTimeout, "ssm-wait-timeout", 5*time.Minute, "How long to wait for all SSM agents to register") + fs.DurationVar(&cfg.SSMWaitInterval, "ssm-wait-interval", 10*time.Second, "Polling interval for SSM agent registration") + fs.BoolVar(&cfg.SkipSSMWait, "skip-ssm-wait", false, "Skip waiting for SSM agents") + fs.DurationVar(&cfg.InstanceRunningTimeout, "instance-running-timeout", defaultWaiterMaxWait, "Max wait for all instances to reach running state") + + fs.StringVar(&cfg.OutputPath, "output", "", "Output destination; empty means stdout. Set to $GITHUB_OUTPUT or $GITHUB_ENV to feed into Actions") + fs.StringVar(&cfg.OutputFormat, "output-format", "auto", "auto | github-output | github-env | json") + + if err := fs.Parse(args); err != nil { + return fmt.Errorf("parse provision flags: %w", err) + } + cfg.VolumeSizeGB = int32(volumeSizeGB) + + if err := finalizeProvisionConfig(&cfg, consumerTag, instanceRoles); err != nil { + return err + } + + awsCfg, err := loadAWSConfig(ctx, cfg.Region) + if err != nil { + return fmt.Errorf("load aws config: %w", err) + } + ec2Client := ec2.NewFromConfig(awsCfg) + ssmClient := ssm.NewFromConfig(awsCfg) + waiter := &ec2WaiterAdapter{ + client: ec2Client, + instanceRunningTimeout: cfg.InstanceRunningTimeout, + } + + ids, err := Provision(ctx, logger, ec2Client, ssmClient, waiter, cfg) + if err != nil { + // Provision returns whatever it managed to create so the caller can + // pipe it into cleanup. We always emit so the action can capture IDs + // for cleanup even on failure. + _ = emitOutput(logger, cfg.OutputPath, cfg.OutputFormat, ids) + return err + } + return emitOutput(logger, cfg.OutputPath, cfg.OutputFormat, ids) +} + +// finalizeProvisionConfig validates required fields and parses raw form +// values (consumerTag, instanceRolesCSV) into cfg's derived fields. Pure +// function — no AWS, no I/O — so it's directly testable. +func finalizeProvisionConfig(cfg *ProvisionConfig, consumerTag, instanceRolesCSV string) error { + if cfg.Region == "" { + return errors.New("-region is required") + } + if cfg.RunID == "" { + return errors.New("-run-id is required") + } + if cfg.SGName == "" { + return errors.New("-sg-name is required") + } + if consumerTag == "" { + return errors.New("-consumer-tag is required (KEY=VALUE)") + } + eq := strings.IndexByte(consumerTag, '=') + if eq <= 0 || eq == len(consumerTag)-1 { + return fmt.Errorf("-consumer-tag must be KEY=VALUE, got %q", consumerTag) + } + cfg.ConsumerTagKey = consumerTag[:eq] + cfg.ConsumerTagVal = consumerTag[eq+1:] + + if cfg.AMIID == "" && (cfg.AMIOwner == "" || cfg.AMIFilter == "") { + return errors.New("-ami-id, OR both of -ami-owner and -ami-filter, are required") + } + cfg.InstanceRoles = splitCSV(instanceRolesCSV) + if len(cfg.InstanceRoles) == 0 { + return errors.New("-instance-roles must contain at least one role") + } + if cfg.VolumeSizeGB <= 0 { + return errors.New("-volume-size-gb must be > 0") + } + return nil +} + +// Provision is the testable core of the provision command. It mutates +// nothing on the host, only AWS via the supplied EC2/SSM clients. +func Provision( + ctx context.Context, + logger *slog.Logger, + c EC2API, + s SSMAPI, + waiter EC2Waiter, + cfg ProvisionConfig, +) (ResourceIDs, error) { + ids := ResourceIDs{InstanceIDByRole: map[string]string{}} + + // VPC + vpcOut, err := c.CreateVpc(ctx, &ec2.CreateVpcInput{ + CidrBlock: aws.String(cfg.VPCCIDR), + TagSpecifications: tagSpec(types.ResourceTypeVpc, cfg, ""), + }) + if err != nil { + return ids, fmt.Errorf("create vpc: %w", err) + } + ids.VPCID = aws.ToString(vpcOut.Vpc.VpcId) + logger.Info("created vpc", "vpc_id", ids.VPCID) + + // VPC attributes — DNS support + hostnames (so the public DNS name is + // resolvable, which the existing workflows depend on). + if _, err := c.ModifyVpcAttribute(ctx, &ec2.ModifyVpcAttributeInput{ + VpcId: aws.String(ids.VPCID), + EnableDnsSupport: &types.AttributeBooleanValue{Value: aws.Bool(true)}, + }); err != nil { + return ids, fmt.Errorf("enable dns support: %w", err) + } + if _, err := c.ModifyVpcAttribute(ctx, &ec2.ModifyVpcAttributeInput{ + VpcId: aws.String(ids.VPCID), + EnableDnsHostnames: &types.AttributeBooleanValue{Value: aws.Bool(true)}, + }); err != nil { + return ids, fmt.Errorf("enable dns hostnames: %w", err) + } + + // Internet gateway + igwOut, err := c.CreateInternetGateway(ctx, &ec2.CreateInternetGatewayInput{ + TagSpecifications: tagSpec(types.ResourceTypeInternetGateway, cfg, ""), + }) + if err != nil { + return ids, fmt.Errorf("create internet gateway: %w", err) + } + ids.IGWID = aws.ToString(igwOut.InternetGateway.InternetGatewayId) + logger.Info("created igw", "igw_id", ids.IGWID) + + if _, err := c.AttachInternetGateway(ctx, &ec2.AttachInternetGatewayInput{ + InternetGatewayId: aws.String(ids.IGWID), + VpcId: aws.String(ids.VPCID), + }); err != nil { + return ids, fmt.Errorf("attach internet gateway: %w", err) + } + + // AZ — auto-pick the first AZ if not given + az := cfg.AvailabilityZone + if az == "" { + azOut, err := c.DescribeAvailabilityZones(ctx, &ec2.DescribeAvailabilityZonesInput{}) + if err != nil { + return ids, fmt.Errorf("describe availability zones: %w", err) + } + if len(azOut.AvailabilityZones) == 0 { + return ids, errors.New("no availability zones returned for region") + } + az = aws.ToString(azOut.AvailabilityZones[0].ZoneName) + } + + // Subnet + subnetOut, err := c.CreateSubnet(ctx, &ec2.CreateSubnetInput{ + VpcId: aws.String(ids.VPCID), + CidrBlock: aws.String(cfg.SubnetCIDR), + AvailabilityZone: aws.String(az), + TagSpecifications: tagSpec(types.ResourceTypeSubnet, cfg, ""), + }) + if err != nil { + return ids, fmt.Errorf("create subnet: %w", err) + } + ids.SubnetID = aws.ToString(subnetOut.Subnet.SubnetId) + logger.Info("created subnet", "subnet_id", ids.SubnetID, "az", az) + + if _, err := c.ModifySubnetAttribute(ctx, &ec2.ModifySubnetAttributeInput{ + SubnetId: aws.String(ids.SubnetID), + MapPublicIpOnLaunch: &types.AttributeBooleanValue{Value: aws.Bool(true)}, + }); err != nil { + return ids, fmt.Errorf("modify subnet attribute (map-public-ip): %w", err) + } + + // Route table + default route + association + rtOut, err := c.CreateRouteTable(ctx, &ec2.CreateRouteTableInput{ + VpcId: aws.String(ids.VPCID), + TagSpecifications: tagSpec(types.ResourceTypeRouteTable, cfg, ""), + }) + if err != nil { + return ids, fmt.Errorf("create route table: %w", err) + } + ids.RouteTableID = aws.ToString(rtOut.RouteTable.RouteTableId) + + if _, err := c.CreateRoute(ctx, &ec2.CreateRouteInput{ + RouteTableId: aws.String(ids.RouteTableID), + DestinationCidrBlock: aws.String("0.0.0.0/0"), + GatewayId: aws.String(ids.IGWID), + }); err != nil { + return ids, fmt.Errorf("create route: %w", err) + } + + assocOut, err := c.AssociateRouteTable(ctx, &ec2.AssociateRouteTableInput{ + RouteTableId: aws.String(ids.RouteTableID), + SubnetId: aws.String(ids.SubnetID), + }) + if err != nil { + return ids, fmt.Errorf("associate route table: %w", err) + } + ids.RouteAssocID = aws.ToString(assocOut.AssociationId) + logger.Info("created route table", "rt_id", ids.RouteTableID, "assoc_id", ids.RouteAssocID) + + // Security group + ingress rules + desc := cfg.SGDescription + if desc == "" { + desc = fmt.Sprintf("aws-test-infra %s", cfg.RunID) + } + sgOut, err := c.CreateSecurityGroup(ctx, &ec2.CreateSecurityGroupInput{ + GroupName: aws.String(cfg.SGName), + Description: aws.String(desc), + VpcId: aws.String(ids.VPCID), + TagSpecifications: tagSpec(types.ResourceTypeSecurityGroup, cfg, ""), + }) + if err != nil { + return ids, fmt.Errorf("create security group: %w", err) + } + ids.SecurityGroupID = aws.ToString(sgOut.GroupId) + logger.Info("created security group", "sg_id", ids.SecurityGroupID) + + for _, rule := range cfg.IngressRules { + ipPerm := types.IpPermission{ + IpProtocol: aws.String(rule.Protocol), + FromPort: aws.Int32(rule.FromPort), + ToPort: aws.Int32(rule.ToPort), + IpRanges: []types.IpRange{{CidrIp: aws.String(rule.CIDR)}}, + } + if _, err := c.AuthorizeSecurityGroupIngress(ctx, &ec2.AuthorizeSecurityGroupIngressInput{ + GroupId: aws.String(ids.SecurityGroupID), + IpPermissions: []types.IpPermission{ipPerm}, + }); err != nil { + return ids, fmt.Errorf("authorize ingress %s:%d:%d:%s: %w", rule.Protocol, rule.FromPort, rule.ToPort, rule.CIDR, err) + } + } + + // Resolve AMI if not given + amiID := cfg.AMIID + if amiID == "" { + amiID, err = resolveAMI(ctx, c, cfg.AMIOwner, cfg.AMIFilter, cfg.AMIArchitecture, cfg.AMIVirtualizationType) + if err != nil { + return ids, err + } + logger.Info("resolved ami", "ami_id", amiID, "filter", cfg.AMIFilter) + } + ids.AMIID = amiID + + // User data (optional) + var userDataB64 *string + if cfg.UserDataFile != "" { + raw, err := os.ReadFile(cfg.UserDataFile) + if err != nil { + return ids, fmt.Errorf("read user-data file: %w", err) + } + b64 := base64.StdEncoding.EncodeToString(raw) + userDataB64 = aws.String(b64) + } + + // Launch instances per role + for _, role := range cfg.InstanceRoles { + instOut, err := c.RunInstances(ctx, &ec2.RunInstancesInput{ + ImageId: aws.String(amiID), + InstanceType: types.InstanceType(cfg.InstanceType), + MinCount: aws.Int32(1), + MaxCount: aws.Int32(1), + SubnetId: aws.String(ids.SubnetID), + SecurityGroupIds: []string{ids.SecurityGroupID}, + IamInstanceProfile: instanceProfileSpec(cfg.InstanceProfile), + BlockDeviceMappings: []types.BlockDeviceMapping{{ + DeviceName: aws.String(cfg.RootDevice), + Ebs: &types.EbsBlockDevice{ + VolumeSize: aws.Int32(cfg.VolumeSizeGB), + VolumeType: types.VolumeTypeGp3, + DeleteOnTermination: aws.Bool(true), + }, + }}, + UserData: userDataB64, + TagSpecifications: tagSpec(types.ResourceTypeInstance, cfg, role), + }) + if err != nil { + return ids, fmt.Errorf("run instances (role=%s): %w", role, err) + } + if len(instOut.Instances) == 0 { + return ids, fmt.Errorf("run instances (role=%s) returned no instances", role) + } + instID := aws.ToString(instOut.Instances[0].InstanceId) + ids.InstanceIDs = append(ids.InstanceIDs, instID) + ids.InstanceIDByRole[role] = instID + logger.Info("launched instance", "role", role, "instance_id", instID) + } + + // Wait for instance-running on all + if err := waiter.WaitInstanceRunning(ctx, ids.InstanceIDs); err != nil { + return ids, fmt.Errorf("wait instance-running: %w", err) + } + + // Pull primary public IP (the existing workflows use the public IP of + // the first instance — by convention, "primary" — for runner→primary + // and worker→primary connectivity). + if len(ids.InstanceIDs) > 0 { + descOut, err := c.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + InstanceIds: []string{ids.InstanceIDs[0]}, + }) + if err != nil { + return ids, fmt.Errorf("describe primary instance: %w", err) + } + if len(descOut.Reservations) > 0 && len(descOut.Reservations[0].Instances) > 0 { + ids.PrimaryPublicIP = aws.ToString(descOut.Reservations[0].Instances[0].PublicIpAddress) + } + } + + // SSM agent registration wait + if !cfg.SkipSSMWait { + if err := waitSSMOnline(ctx, logger, s, ids.InstanceIDs, cfg.SSMWaitTimeout, cfg.SSMWaitInterval); err != nil { + return ids, err + } + } + + return ids, nil +} + +func resolveAMI(ctx context.Context, c EC2API, owner, filter, architecture, virtualizationType string) (string, error) { + filters := []types.Filter{ + {Name: aws.String("name"), Values: []string{filter}}, + {Name: aws.String("state"), Values: []string{"available"}}, + } + if architecture != "" { + filters = append(filters, types.Filter{Name: aws.String("architecture"), Values: []string{architecture}}) + } + if virtualizationType != "" { + filters = append(filters, types.Filter{Name: aws.String("virtualization-type"), Values: []string{virtualizationType}}) + } + out, err := c.DescribeImages(ctx, &ec2.DescribeImagesInput{ + Owners: []string{owner}, + Filters: filters, + }) + if err != nil { + return "", fmt.Errorf("describe images: %w", err) + } + if len(out.Images) == 0 { + return "", fmt.Errorf("no AMIs found for owner=%s filter=%s", owner, filter) + } + // Pick the latest by CreationDate. + latestIdx := 0 + for i := 1; i < len(out.Images); i++ { + if aws.ToString(out.Images[i].CreationDate) > aws.ToString(out.Images[latestIdx].CreationDate) { + latestIdx = i + } + } + return aws.ToString(out.Images[latestIdx].ImageId), nil +} + +func waitSSMOnline( + ctx context.Context, + logger *slog.Logger, + s SSMAPI, + instanceIDs []string, + timeout time.Duration, + interval time.Duration, +) error { + deadline := time.Now().Add(timeout) + for { + out, err := s.DescribeInstanceInformation(ctx, &ssm.DescribeInstanceInformationInput{ + Filters: []ssmtypes.InstanceInformationStringFilter{ + {Key: aws.String("InstanceIds"), Values: instanceIDs}, + }, + }) + if err == nil { + online := 0 + for _, info := range out.InstanceInformationList { + if info.PingStatus == ssmtypes.PingStatusOnline { + online++ + } + } + logger.Info("ssm wait", "online", online, "total", len(instanceIDs)) + if online == len(instanceIDs) { + return nil + } + } else { + logger.Warn("describe-instance-information errored, retrying", "err", err) + } + if time.Now().After(deadline) { + return fmt.Errorf("timed out waiting for SSM agents (%d instances)", len(instanceIDs)) + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(interval): + } + } +} + +func tagSpec(rt types.ResourceType, cfg ProvisionConfig, role string) []types.TagSpecification { + tags := []types.Tag{ + {Key: aws.String(cfg.ConsumerTagKey), Value: aws.String(cfg.ConsumerTagVal)}, + {Key: aws.String("RunID"), Value: aws.String(cfg.RunID)}, + } + if role != "" { + tags = append(tags, types.Tag{Key: aws.String("Role"), Value: aws.String(role)}) + } + return []types.TagSpecification{{ResourceType: rt, Tags: tags}} +} + +func instanceProfileSpec(name string) *types.IamInstanceProfileSpecification { + if name == "" { + return nil + } + return &types.IamInstanceProfileSpecification{Name: aws.String(name)} +} + +func splitCSV(s string) []string { + if s == "" { + return nil + } + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + return out +} + diff --git a/.github/actions/aws-test-infra/src/provision_test.go b/.github/actions/aws-test-infra/src/provision_test.go new file mode 100644 index 0000000..fe3166f --- /dev/null +++ b/.github/actions/aws-test-infra/src/provision_test.go @@ -0,0 +1,462 @@ +package main + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +func baseProvisionConfig() ProvisionConfig { + return ProvisionConfig{ + Region: "us-west-2", + RunID: "run-42", + ConsumerTagKey: "SELinuxE2E", + ConsumerTagVal: "true", + VPCCIDR: "10.0.0.0/16", + SubnetCIDR: "10.0.1.0/24", + AMIOwner: "099720109477", + AMIFilter: "ubuntu/images/hvm-ssd*/ubuntu-jammy-22.04-amd64-server-*", + SGName: "selinux-e2e-42", + IngressRules: []IngressRule{ + {Protocol: "-1", FromPort: -1, ToPort: -1, CIDR: "10.0.0.0/16"}, + {Protocol: "tcp", FromPort: 8443, ToPort: 8443, CIDR: "0.0.0.0/0"}, + }, + InstanceType: "m5.xlarge", + InstanceProfile: "e2e-test-executor", + InstanceRoles: []string{"primary", "worker1", "worker2"}, + RootDevice: "/dev/sda1", + VolumeSizeGB: 100, + SkipSSMWait: true, + } +} + +func TestProvision_HappyPath_Ordering(t *testing.T) { + c := &fakeEC2{} + s := &fakeSSM{online: 3} + w := &fakeWaiter{} + + ids, err := Provision(context.Background(), newTestLogger(), c, s, w, baseProvisionConfig()) + if err != nil { + t.Fatalf("Provision: %v", err) + } + + want := []string{ + "CreateVpc", + "ModifyVpcAttribute", // dns-support + "ModifyVpcAttribute", // dns-hostnames + "CreateInternetGateway", + "AttachInternetGateway", + "DescribeAvailabilityZones", + "CreateSubnet", + "ModifySubnetAttribute", // map-public-ip + "CreateRouteTable", + "CreateRoute", + "AssociateRouteTable", + "CreateSecurityGroup", + "AuthorizeSecurityGroupIngress", // rule 1 + "AuthorizeSecurityGroupIngress", // rule 2 + "DescribeImages", + "RunInstances", // primary + "RunInstances", // worker1 + "RunInstances", // worker2 + "DescribeInstances", // primary public IP + } + seq := methodSequence(c.calls) + if err := requireOrdering(seq, want); err != nil { + t.Fatal(err) + } + + // Strict precedes-checks for the dependency-critical pairs in + // provisioning. These catch insertion bugs that requireOrdering + // would silently accept. + // + // IGW must be attached before subnet creation; otherwise the + // implicit route we add below would have no working gateway. + if err := requireBefore(seq, "AttachInternetGateway", "CreateSubnet"); err != nil { + t.Errorf("AttachInternetGateway must precede CreateSubnet: %v", err) + } + // Route table must exist before associate. CreateRoute and + // AssociateRouteTable both need the route table ID. + if err := requireBefore(seq, "CreateRouteTable", "AssociateRouteTable"); err != nil { + t.Errorf("CreateRouteTable must precede AssociateRouteTable: %v", err) + } + // Security group must be created before any ingress authorization. + if err := requireBefore(seq, "CreateSecurityGroup", "AuthorizeSecurityGroupIngress"); err != nil { + t.Errorf("CreateSecurityGroup must precede AuthorizeSecurityGroupIngress: %v", err) + } + // AMI lookup must complete before instances are launched. + if err := requireBefore(seq, "DescribeImages", "RunInstances"); err != nil { + t.Errorf("DescribeImages must precede RunInstances: %v", err) + } + + if ids.VPCID == "" || ids.IGWID == "" || ids.SubnetID == "" || ids.RouteTableID == "" || ids.SecurityGroupID == "" { + t.Errorf("ResourceIDs has empty fields: %+v", ids) + } + if got, want := len(ids.InstanceIDs), 3; got != want { + t.Errorf("InstanceIDs count = %d, want %d (%v)", got, want, ids.InstanceIDs) + } + if ids.AMIID != "ami-newest" { + t.Errorf("AMIID = %q, want ami-newest (latest CreationDate)", ids.AMIID) + } + if ids.PrimaryPublicIP != "203.0.113.1" { + t.Errorf("PrimaryPublicIP = %q, want 203.0.113.1", ids.PrimaryPublicIP) + } + if got := ids.InstanceIDByRole["primary"]; got != "i-primary" { + t.Errorf("InstanceIDByRole[primary] = %q, want i-primary", got) + } +} + +func TestProvision_TagsAppliedToEveryResource(t *testing.T) { + c := &fakeEC2{} + s := &fakeSSM{online: 3} + w := &fakeWaiter{} + + if _, err := Provision(context.Background(), newTestLogger(), c, s, w, baseProvisionConfig()); err != nil { + t.Fatalf("Provision: %v", err) + } + + wantResourceTypes := []types.ResourceType{ + types.ResourceTypeVpc, + types.ResourceTypeInternetGateway, + types.ResourceTypeSubnet, + types.ResourceTypeRouteTable, + types.ResourceTypeSecurityGroup, + types.ResourceTypeInstance, + } + + for _, rt := range wantResourceTypes { + found := false + for _, call := range c.calls { + for _, ts := range call.TagSpec { + if ts.ResourceType != rt { + continue + } + keyVals := map[string]string{} + for _, t := range ts.Tags { + keyVals[aws.ToString(t.Key)] = aws.ToString(t.Value) + } + if keyVals["SELinuxE2E"] != "true" { + t.Errorf("%s: missing/wrong SELinuxE2E tag (got %q)", rt, keyVals["SELinuxE2E"]) + } + if keyVals["RunID"] != "run-42" { + t.Errorf("%s: missing/wrong RunID tag (got %q)", rt, keyVals["RunID"]) + } + found = true + } + } + if !found { + t.Errorf("no call tagged a %s resource — provision must tag every created resource", rt) + } + } +} + +func TestProvision_InstanceRoleTags(t *testing.T) { + c := &fakeEC2{} + s := &fakeSSM{online: 3} + w := &fakeWaiter{} + + if _, err := Provision(context.Background(), newTestLogger(), c, s, w, baseProvisionConfig()); err != nil { + t.Fatalf("Provision: %v", err) + } + + rolesSeen := map[string]bool{} + for _, call := range c.calls { + if call.Method != "RunInstances" { + continue + } + for _, ts := range call.TagSpec { + if ts.ResourceType != types.ResourceTypeInstance { + continue + } + for _, tag := range ts.Tags { + if aws.ToString(tag.Key) == "Role" { + rolesSeen[aws.ToString(tag.Value)] = true + } + } + } + } + for _, want := range []string{"primary", "worker1", "worker2"} { + if !rolesSeen[want] { + t.Errorf("no RunInstances call tagged Role=%s — workflow consumers depend on this for cleanup", want) + } + } +} + +func TestProvision_AMIFilters(t *testing.T) { + // The DescribeImages filter set must include architecture and + // virtualization-type ONLY when the caller sets them. Defaults to + // x86_64+hvm in action.yml to match the original Bash; binary + // defaults to empty to allow arm64 lookups via the -ami-architecture + // flag. + tests := []struct { + name string + architecture string + virtualizationType string + wantArch string // empty = filter must be absent + wantVirt string + }{ + {name: "both set (x86_64 + hvm — Bash safety net)", architecture: "x86_64", virtualizationType: "hvm", wantArch: "x86_64", wantVirt: "hvm"}, + {name: "both empty (binary default — allows any)", architecture: "", virtualizationType: "", wantArch: "", wantVirt: ""}, + {name: "architecture-only (arm64 lookup)", architecture: "arm64", virtualizationType: "", wantArch: "arm64", wantVirt: ""}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + c := &fakeEC2{} + cfg := baseProvisionConfig() + cfg.AMIArchitecture = tt.architecture + cfg.AMIVirtualizationType = tt.virtualizationType + + if _, err := Provision(context.Background(), newTestLogger(), c, &fakeSSM{}, &fakeWaiter{}, cfg); err != nil { + t.Fatalf("Provision: %v", err) + } + + var got *ec2.DescribeImagesInput + for _, call := range c.calls { + if call.Method == "DescribeImages" { + got = call.Input.(*ec2.DescribeImagesInput) + break + } + } + if got == nil { + t.Fatal("no DescribeImages call recorded") + } + gotArch, gotVirt := "", "" + for _, f := range got.Filters { + switch aws.ToString(f.Name) { + case "architecture": + if len(f.Values) == 1 { + gotArch = f.Values[0] + } + case "virtualization-type": + if len(f.Values) == 1 { + gotVirt = f.Values[0] + } + } + } + if gotArch != tt.wantArch { + t.Errorf("architecture filter = %q, want %q", gotArch, tt.wantArch) + } + if gotVirt != tt.wantVirt { + t.Errorf("virtualization-type filter = %q, want %q", gotVirt, tt.wantVirt) + } + }) + } +} + +func TestProvision_AMILookupPicksNewestByCreationDate(t *testing.T) { + c := &fakeEC2{} + s := &fakeSSM{online: 0} + w := &fakeWaiter{} + cfg := baseProvisionConfig() + + ids, err := Provision(context.Background(), newTestLogger(), c, s, w, cfg) + if err != nil { + t.Fatalf("Provision: %v", err) + } + // fakeEC2.DescribeImages returns three images with mixed dates; + // resolveAMI should pick "ami-newest" (2026-01-01). + if ids.AMIID != "ami-newest" { + t.Errorf("AMIID = %q, want ami-newest", ids.AMIID) + } +} + +func TestProvision_FailureAtEachStage(t *testing.T) { + // Five representative failure points that prove the load-bearing + // "return whatever IDs you collected, so cleanup can tear them down" + // contract holds across the whole provision flow. Earlier runs of + // this test had a row per AWS call (13 total); these 5 cover early / + // post-VPC / mid-build / AMI-lookup / late-stage without the + // repetition. + tests := []struct { + name string + failOn string + errSubstring string + expectVPCID bool + expectIGWID bool + expectSubnetID bool + expectRTID bool + expectSGID bool + expectAMIID bool + }{ + { + name: "CreateVpc fails (very-early — nothing collected)", + failOn: "CreateVpc", + errSubstring: "create vpc", + }, + { + name: "AttachInternetGateway fails (post-VPC, post-IGW)", + failOn: "AttachInternetGateway", + errSubstring: "attach", + expectVPCID: true, + expectIGWID: true, + }, + { + name: "CreateSecurityGroup fails (mid-build)", + failOn: "CreateSecurityGroup", + errSubstring: "security group", + expectVPCID: true, + expectIGWID: true, + expectSubnetID: true, + expectRTID: true, + }, + { + name: "DescribeImages fails (AMI lookup — all infra collected, no AMI)", + failOn: "DescribeImages", + errSubstring: "describe images", + expectVPCID: true, + expectIGWID: true, + expectSubnetID: true, + expectRTID: true, + expectSGID: true, + }, + { + name: "RunInstances fails (late — everything collected except instance IDs)", + failOn: "RunInstances", + errSubstring: "run instances", + expectVPCID: true, + expectIGWID: true, + expectSubnetID: true, + expectRTID: true, + expectSGID: true, + expectAMIID: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + c := &fakeEC2{failOn: map[string]error{tt.failOn: errStaged}} + s := &fakeSSM{online: 0} + w := &fakeWaiter{} + + ids, err := Provision(context.Background(), newTestLogger(), c, s, w, baseProvisionConfig()) + if err == nil { + t.Fatalf("expected error from staged %s failure", tt.failOn) + } + if !strings.Contains(strings.ToLower(err.Error()), tt.errSubstring) { + t.Errorf("error %q does not mention %q (lower-cased) — debugging this failure in CI will be harder", err.Error(), tt.errSubstring) + } + // The "return whatever you collected so far" contract: cleanup + // can only tear down what's reflected in ResourceIDs. + gotIDs := map[string]bool{ + "VPCID": ids.VPCID != "", + "IGWID": ids.IGWID != "", + "SubnetID": ids.SubnetID != "", + "RouteTableID": ids.RouteTableID != "", + "SecurityGroupID": ids.SecurityGroupID != "", + "AMIID": ids.AMIID != "", + } + wantIDs := map[string]bool{ + "VPCID": tt.expectVPCID, + "IGWID": tt.expectIGWID, + "SubnetID": tt.expectSubnetID, + "RouteTableID": tt.expectRTID, + "SecurityGroupID": tt.expectSGID, + "AMIID": tt.expectAMIID, + } + for k, want := range wantIDs { + if gotIDs[k] != want { + t.Errorf("%s populated=%v, want %v", k, gotIDs[k], want) + } + } + }) + } +} + +func TestProvision_SSMWaitSucceedsWhenAllOnline(t *testing.T) { + c := &fakeEC2{} + s := &fakeSSM{online: 3} + w := &fakeWaiter{} + cfg := baseProvisionConfig() + cfg.SkipSSMWait = false + cfg.SSMWaitTimeout = time.Second + cfg.SSMWaitInterval = 10 * time.Millisecond + + if _, err := Provision(context.Background(), newTestLogger(), c, s, w, cfg); err != nil { + t.Fatalf("Provision: %v", err) + } + if s.calls < 1 { + t.Errorf("expected at least one SSM call, got %d", s.calls) + } +} + +func TestProvision_SSMWaitTimeout(t *testing.T) { + c := &fakeEC2{} + s := &fakeSSM{online: 1} // only 1 of 3 online — never satisfies + w := &fakeWaiter{} + cfg := baseProvisionConfig() + cfg.SkipSSMWait = false + cfg.SSMWaitTimeout = 50 * time.Millisecond + cfg.SSMWaitInterval = 10 * time.Millisecond + + _, err := Provision(context.Background(), newTestLogger(), c, s, w, cfg) + if err == nil { + t.Fatal("expected timeout error") + } + if !strings.Contains(err.Error(), "timed out") { + t.Errorf("expected timeout error, got: %v", err) + } +} + +func TestProvision_IngressRuleEncoding(t *testing.T) { + // Sanity-check that the SDK call shapes match what the existing Bash + // produced — protocol, port range, and CIDR must round-trip exactly. + c := &fakeEC2{} + s := &fakeSSM{online: 0} + w := &fakeWaiter{} + cfg := baseProvisionConfig() + cfg.SkipSSMWait = true + cfg.IngressRules = []IngressRule{ + {Protocol: "-1", FromPort: -1, ToPort: -1, CIDR: "10.0.0.0/16"}, + {Protocol: "tcp", FromPort: 30000, ToPort: 32767, CIDR: "0.0.0.0/0"}, + {Protocol: "icmp", FromPort: -1, ToPort: -1, CIDR: "10.0.0.0/16"}, + } + + if _, err := Provision(context.Background(), newTestLogger(), c, s, w, cfg); err != nil { + t.Fatalf("Provision: %v", err) + } + + type encoded struct { + Protocol string + From int32 + To int32 + CIDR string + } + var got []encoded + for _, call := range c.calls { + if call.Method != "AuthorizeSecurityGroupIngress" { + continue + } + in := call.Input.(*ec2.AuthorizeSecurityGroupIngressInput) + for _, p := range in.IpPermissions { + cidr := "" + if len(p.IpRanges) > 0 { + cidr = aws.ToString(p.IpRanges[0].CidrIp) + } + got = append(got, encoded{ + Protocol: aws.ToString(p.IpProtocol), + From: aws.ToInt32(p.FromPort), + To: aws.ToInt32(p.ToPort), + CIDR: cidr, + }) + } + } + want := []encoded{ + {Protocol: "-1", From: -1, To: -1, CIDR: "10.0.0.0/16"}, + {Protocol: "tcp", From: 30000, To: 32767, CIDR: "0.0.0.0/0"}, + {Protocol: "icmp", From: -1, To: -1, CIDR: "10.0.0.0/16"}, + } + if len(got) != len(want) { + t.Fatalf("got %d ingress calls, want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("ingress[%d] = %+v, want %+v", i, got[i], want[i]) + } + } +} diff --git a/.github/actions/aws-test-infra/src/validation_test.go b/.github/actions/aws-test-infra/src/validation_test.go new file mode 100644 index 0000000..a8ae37c --- /dev/null +++ b/.github/actions/aws-test-infra/src/validation_test.go @@ -0,0 +1,177 @@ +package main + +import ( + "strings" + "testing" +) + +// validProvisionConfig returns a config that passes finalizeProvisionConfig +// — tests below mutate one field at a time to verify each validation +// branch. +func validProvisionConfig() ProvisionConfig { + return ProvisionConfig{ + Region: "us-west-2", + RunID: "run-42", + SGName: "sg-test", + AMIOwner: "099720109477", + AMIFilter: "ubuntu-jammy*", + VolumeSizeGB: 100, + } +} + +func TestFinalizeProvisionConfig_HappyPath(t *testing.T) { + cfg := validProvisionConfig() + if err := finalizeProvisionConfig(&cfg, "SELinuxE2E=true", "primary,worker1,worker2"); err != nil { + t.Fatalf("happy path errored: %v", err) + } + if cfg.ConsumerTagKey != "SELinuxE2E" || cfg.ConsumerTagVal != "true" { + t.Errorf("consumer-tag mis-parsed: key=%q val=%q", cfg.ConsumerTagKey, cfg.ConsumerTagVal) + } + if len(cfg.InstanceRoles) != 3 || cfg.InstanceRoles[0] != "primary" { + t.Errorf("InstanceRoles mis-parsed: %v", cfg.InstanceRoles) + } +} + +func TestFinalizeProvisionConfig_RejectsBadInput(t *testing.T) { + tests := []struct { + name string + mutate func(*ProvisionConfig) + consumerTag string + roles string + wantErrSub string + }{ + { + name: "missing region", + mutate: func(c *ProvisionConfig) { c.Region = "" }, + consumerTag: "SELinuxE2E=true", + roles: "primary", + wantErrSub: "-region is required", + }, + { + name: "missing run-id", + mutate: func(c *ProvisionConfig) { c.RunID = "" }, + consumerTag: "SELinuxE2E=true", + roles: "primary", + wantErrSub: "-run-id is required", + }, + { + name: "missing sg-name", + mutate: func(c *ProvisionConfig) { c.SGName = "" }, + consumerTag: "SELinuxE2E=true", + roles: "primary", + wantErrSub: "-sg-name is required", + }, + { + name: "empty consumer-tag", + mutate: func(c *ProvisionConfig) {}, + consumerTag: "", + roles: "primary", + wantErrSub: "-consumer-tag is required", + }, + { + // Covers the three sub-cases of the same check + // (`eq <= 0 || eq == len(consumerTag)-1`): missing-equals, + // empty-key, empty-value all hit the same branch. + name: "consumer-tag malformed", + mutate: func(c *ProvisionConfig) {}, + consumerTag: "SELinuxE2Etrue", + roles: "primary", + wantErrSub: "must be KEY=VALUE", + }, + { + name: "no AMI source", + mutate: func(c *ProvisionConfig) { + c.AMIID = "" + c.AMIOwner = "" + c.AMIFilter = "" + }, + consumerTag: "SELinuxE2E=true", + roles: "primary", + wantErrSub: "-ami-id, OR both of -ami-owner and -ami-filter", + }, + { + // Both empty and whitespace-only inputs hit the same + // `len(cfg.InstanceRoles) == 0` check after splitCSV. + name: "instance-roles produces no roles", + mutate: func(c *ProvisionConfig) {}, + consumerTag: "SELinuxE2E=true", + roles: " , ,", + wantErrSub: "instance-roles must contain at least one role", + }, + { + name: "non-positive volume-size", + mutate: func(c *ProvisionConfig) { c.VolumeSizeGB = 0 }, + consumerTag: "SELinuxE2E=true", + roles: "primary", + wantErrSub: "-volume-size-gb must be > 0", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + cfg := validProvisionConfig() + tt.mutate(&cfg) + err := finalizeProvisionConfig(&cfg, tt.consumerTag, tt.roles) + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSub) + } + if !strings.Contains(err.Error(), tt.wantErrSub) { + t.Errorf("error %q does not contain %q", err.Error(), tt.wantErrSub) + } + }) + } +} + +func TestFinalizeProvisionConfig_AMIIDAloneIsValid(t *testing.T) { + // Passing -ami-id with no owner/filter should be accepted: the binary + // uses the literal AMI and skips DescribeImages. + cfg := validProvisionConfig() + cfg.AMIOwner = "" + cfg.AMIFilter = "" + cfg.AMIID = "ami-pinned" + if err := finalizeProvisionConfig(&cfg, "SELinuxE2E=true", "primary"); err != nil { + t.Errorf("ami-id-only config rejected: %v", err) + } +} + +func TestFinalizeCleanupConfig_RejectsBadInput(t *testing.T) { + tests := []struct { + name string + cfg CleanupConfig + wantErrSub string + }{ + { + name: "missing region", + cfg: CleanupConfig{RunID: "run-42"}, + wantErrSub: "-region is required", + }, + { + name: "missing run-id with sweep enabled", + cfg: CleanupConfig{Region: "us-west-2"}, + wantErrSub: "-run-id is required", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + err := finalizeCleanupConfig(&tt.cfg, "") + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSub) + } + if !strings.Contains(err.Error(), tt.wantErrSub) { + t.Errorf("error %q does not contain %q", err.Error(), tt.wantErrSub) + } + }) + } +} + +func TestFinalizeCleanupConfig_SkipSweepRelaxesRunID(t *testing.T) { + // With -skip-sweep, run-id isn't needed (no tag-based discovery). + // This is the only way to opt out of the sweep-needs-run-id check. + cfg := CleanupConfig{Region: "us-west-2", SkipSweep: true} + if err := finalizeCleanupConfig(&cfg, ""); err != nil { + t.Errorf("skip-sweep without run-id rejected: %v", err) + } +} diff --git a/.github/workflows/test-aws-test-infra.yaml b/.github/workflows/test-aws-test-infra.yaml new file mode 100644 index 0000000..3a73932 --- /dev/null +++ b/.github/workflows/test-aws-test-infra.yaml @@ -0,0 +1,29 @@ +name: Test aws-test-infra + +on: + push: + branches: [main] + paths: + - '.github/actions/aws-test-infra/**' + pull_request: + paths: + - '.github/actions/aws-test-infra/**' + +jobs: + test: + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0 + with: + go-version-file: .github/actions/aws-test-infra/src/go.mod + - name: Run tests + run: go test -v -race -count=1 ./... + working-directory: .github/actions/aws-test-infra/src + - name: Verify binary builds + working-directory: .github/actions/aws-test-infra/src + run: CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -trimpath -ldflags="-s -w" -o /dev/null . diff --git a/Makefile b/Makefile index f719e2b..9717582 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test test-semver-validation test-linear-pr-commenter test-release-notification test-linear-release-sync test-cleanup-head-charts test-ci-test-notify test-auto-approve-bot-prs test-ai-pr-review test-ai-step test-publish-helm-chart test-govulncheck test-go-licenses test-run-ginkgo test-sticky-pr-comment test-repository-dispatch build-linear-release-sync lint install-auto-doc generate-docs check-docs help +.PHONY: test test-semver-validation test-linear-pr-commenter test-release-notification test-linear-release-sync test-aws-test-infra test-cleanup-head-charts test-ci-test-notify test-auto-approve-bot-prs test-ai-pr-review test-ai-step test-publish-helm-chart test-govulncheck test-go-licenses test-run-ginkgo test-sticky-pr-comment test-repository-dispatch build-linear-release-sync lint install-auto-doc generate-docs check-docs help ACTIONS_DIR := .github/actions WORKFLOWS_DIR := .github/workflows @@ -93,7 +93,7 @@ check-docs: generate-docs ## verify docs are up to date (fails if drift detected help: ## show this help @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " %-30s %s\n", $$1, $$2}' -test: test-semver-validation test-linear-pr-commenter test-release-notification test-linear-release-sync test-cleanup-head-charts test-auto-approve-bot-prs test-ai-pr-review test-ai-step test-ci-test-notify test-go-licenses test-publish-helm-chart test-govulncheck test-run-ginkgo test-sticky-pr-comment test-repository-dispatch ## run all action tests +test: test-semver-validation test-linear-pr-commenter test-release-notification test-linear-release-sync test-aws-test-infra test-cleanup-head-charts test-auto-approve-bot-prs test-ai-pr-review test-ai-step test-ci-test-notify test-go-licenses test-publish-helm-chart test-govulncheck test-run-ginkgo test-sticky-pr-comment test-repository-dispatch ## run all action tests test-semver-validation: ## run semver-validation unit tests cd $(ACTIONS_DIR)/semver-validation && npm ci --silent && NODE_OPTIONS=--experimental-vm-modules npx jest --ci --coverage --watchAll=false @@ -107,6 +107,9 @@ test-release-notification: ## run release-notification detect-branch tests test-linear-release-sync: ## run linear-release-sync unit tests cd $(ACTIONS_DIR)/linear-release-sync/src && go test -v ./... +test-aws-test-infra: ## run aws-test-infra unit tests + cd $(ACTIONS_DIR)/aws-test-infra/src && go test -v -race -count=1 ./... + test-cleanup-head-charts: ## run cleanup-head-charts bats tests bats $(SCRIPTS_DIR)/cleanup-head-charts/test/cleanup-head-charts.bats