35 changes: 34 additions & 1 deletion README.md
@@ -11,7 +11,7 @@
- [x] Release inference code
- [x] Release beetle part segmentation dataset
- [ ] Release online demo
- [ ] Release one-shot fine-tuning (OC-CCL) code
- [x] Release Open-Close Cycle Consistency Loss (OC-CCL) fine-tuning code
- [x] Release trait retrieval code
- [x] Release butterfly trait segmentation dataset

@@ -104,6 +104,39 @@
python code/segment.py --support_image /path/to/sample/image.png \
--output /path/to/output/folder \
--output_format "png" # png or gif, optional
```
### Fine-tuning with OC-CCL
OC-CCL (Open-Close Cycle Consistency Loss) fine-tunes SAM2 on a target species. The cycle opens with `reference → query` (predicting the query mask) and closes with `query → reference` (predicting a closing mask back on the reference), which is supervised against the reference's ground-truth mask with a combined BCE + Dice loss.
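The two half-cycles can be sketched as follows. This is a minimal illustration, not the repo's implementation: `segment` stands in for a SAM2-style predictor (support image + support mask → query mask), and the exact loss weighting is an assumption.

```python
import numpy as np

def bce_dice(pred, target, eps=1e-6):
    """Combined BCE + Dice loss. pred: probabilities in [0, 1]; target: binary mask."""
    bce = -np.mean(target * np.log(pred + eps) + (1 - target) * np.log(1 - pred + eps))
    inter = (pred * target).sum()
    dice = 1 - (2 * inter + eps) / (pred.sum() + target.sum() + eps)
    return bce + dice

def oc_ccl_step(segment, ref_img, ref_mask, qry_img):
    """One OC-CCL step: open the cycle, close it, supervise against the reference GT."""
    # Open: reference -> query (predict the query mask from the reference prompt)
    qry_pred = segment(ref_img, ref_mask, qry_img)
    # Close: query -> reference (predict a closing mask back on the reference)
    ref_pred = segment(qry_img, qry_pred, ref_img)
    # Supervise the closing mask against the reference's ground-truth mask
    return bce_dice(ref_pred, ref_mask)
```

Note that only the reference needs a ground-truth mask; the query image is unlabeled, which is what makes one-shot fine-tuning possible.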

**1. Get the butterfly images.** Mask annotations are already tracked under `data/cambridge_butterfly/DataSet_Butterfly/`. The image manifest with Zenodo URLs and md5 checksums is committed at `data/cambridge_butterfly/images.csv`. Download with [`cautious-robot`](https://github.com/Imageomics/cautious-robot):
```bash
pip install cautious-robot
cautious-robot -i data/cambridge_butterfly/images.csv \
-o data/cambridge_butterfly/images \
--checksum-algorithm md5 --verifier-col md5
```
Images land at `data/cambridge_butterfly/images/<image_id>.<ext>`. cautious-robot skips existing files, retries 429/5xx responses, and verifies every download against the committed md5. The manifest can be regenerated from the per-species `train_test_separate/*.json` files via `python data/cambridge_butterfly/build_download_csv.py` (queries the Zenodo API for fresh checksums).
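The md5 check that cautious-robot runs on every download can be reproduced by hand, e.g. to re-audit a directory later. `verify_downloads` below is an illustrative helper, not part of the repo; it only assumes the committed manifest schema (`filename`, `file_url`, `md5`).

```python
import csv
import hashlib
from pathlib import Path

def verify_downloads(manifest_csv, image_dir):
    """Recompute each file's md5 and compare to the manifest; return bad/missing filenames."""
    bad = []
    with open(manifest_csv, newline="") as f:
        for row in csv.DictReader(f):
            path = Path(image_dir) / row["filename"]
            if not path.exists():
                bad.append(row["filename"])
                continue
            digest = hashlib.md5(path.read_bytes()).hexdigest()
            if digest != row["md5"]:
                bad.append(row["filename"])
    return bad
```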

**2. Train on one or more species.**
```bash
python src/sst/oc_ccl.py \
--checkpoint checkpoints/sam2_hiera_large.pt \
--species "(malleti x plesseni) x malleti" \
--epochs 10 --lr 1e-5 \
--output_dir outputs/oc_ccl
```
The best checkpoint is written to `<output_dir>/best_model.pt`. Defaults: `--lr 1e-5`, `--batch_size 1`, `--epochs 10`.

**3. Reproduce the ablation grid.** 16 runs across 8 GPUs sweeping learning rate, BCE/Dice weighting, LoRA rank, and memory reset:
```bash
bash experiments/launch_ablations.sh
python experiments/eval_all_ablations.py # writes outputs/ablation/eval_results.json
```
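A 16-run grid over four axes suggests two values per axis. A sketch of how such a sweep maps onto 8 GPUs round-robin; the concrete values below are assumptions, and the real grid lives in `experiments/launch_ablations.sh`:

```python
import itertools

# Assumed sweep values: 2 per axis -> 2**4 = 16 runs.
LRS = [1e-5, 1e-6]
BCE_DICE = [(1.0, 1.0), (0.5, 2.0)]  # (bce_weight, dice_weight)
LORA_RANKS = [0, 8]                  # 0 = full fine-tune, no LoRA
MEMORY_RESET = [False, True]
N_GPUS = 8

runs = list(itertools.product(LRS, BCE_DICE, LORA_RANKS, MEMORY_RESET))
for i, (lr, (w_bce, w_dice), rank, reset) in enumerate(runs):
    gpu = i % N_GPUS  # round-robin: two runs per GPU
    print(f"GPU {gpu}: lr={lr} bce={w_bce} dice={w_dice} "
          f"lora_rank={rank} memory_reset={reset}")
```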

**4. Curriculum variant (top-n% by reconstruction quality).** Precomputes per-sample cycle reconstruction IoU, then trains only on the highest-quality fraction:
```bash
python experiments/curriculum_oc_ccl.py --gpu 0 --epochs 10 --lr 1e-6
```
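The selection step of the curriculum can be sketched as follows: score every training pair by cycle-reconstruction IoU, then keep only the top fraction. Function names here are illustrative, not the script's actual API.

```python
import numpy as np

def iou(pred, target, eps=1e-6):
    """Intersection-over-union of two binary masks."""
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return (inter + eps) / (union + eps)

def select_top_fraction(samples, recon_ious, fraction=0.5):
    """Keep the highest-quality fraction of samples by cycle-reconstruction IoU."""
    k = max(1, int(len(samples) * fraction))
    order = np.argsort(recon_ious)[::-1]  # best-reconstructed first
    return [samples[i] for i in order[:k]]
```

The intuition is that pairs whose cycle reconstruction already scores well give a cleaner training signal than pairs where the closing mask barely overlaps the reference.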

### Trait-Based Retrieval
For trait-based retrieval, please refer to the demo code below:
79 changes: 79 additions & 0 deletions data/cambridge_butterfly/build_download_csv.py
@@ -0,0 +1,79 @@
"""
Build images.csv from the train/test JSON splits, enriched with md5 checksums
fetched from the Zenodo public API. The CSV is consumed by cautious-robot:

cautious-robot -i images.csv -o images/ \
--checksum-algorithm md5 --verifier-col md5

Maintenance script — re-run only if the train_test_separate/*.json files change.
The output images.csv is checked into the repo so users do not need network
access until they actually download images.
"""

import csv
import json
import urllib.request
from pathlib import Path

DATA_ROOT = Path(__file__).resolve().parent
SPLIT_DIR = DATA_ROOT / "train_test_separate"
OUT_CSV = DATA_ROOT / "images.csv"


def parse_zenodo_url(url):
"""https://zenodo.org/record/<id>/files/<name> -> (id, name)."""
record, _, name = url.split("/record/")[1].partition("/files/")
return record, name


def fetch_record_md5s(record_id):
"""Return {filename: md5_hex} for one Zenodo record."""
api_url = f"https://zenodo.org/api/records/{record_id}"
with urllib.request.urlopen(api_url) as r:
meta = json.load(r)
md5s = {}
for f in meta.get("files", []):
digest = f.get("checksum", "")
if digest.startswith("md5:"):
md5s[f["key"]] = digest[len("md5:"):]
return md5s


def main():
entries = {}
for json_file in sorted(SPLIT_DIR.rglob("*.json")):
for image_id, url, _mask in json.load(open(json_file)):
entries.setdefault(image_id, url)

record_ids = sorted({parse_zenodo_url(url)[0] for url in entries.values()})
print(f"Fetching md5s for {len(record_ids)} Zenodo records...")
md5_by_record = {}
for rid in record_ids:
md5_by_record[rid] = fetch_record_md5s(rid)
print(f" record {rid}: {len(md5_by_record[rid])} files")

rows = []
missing = []
for image_id, url in sorted(entries.items()):
rid, name = parse_zenodo_url(url)
md5 = md5_by_record.get(rid, {}).get(name)
if md5 is None:
missing.append((image_id, url))
continue
ext = Path(url).suffix
rows.append((f"{image_id}{ext}", url, md5))

if missing:
raise SystemExit(
f"Missing md5 for {len(missing)} files (first 5): {missing[:5]}"
)

with open(OUT_CSV, "w", newline="") as f:
w = csv.writer(f)
w.writerow(["filename", "file_url", "md5"])
w.writerows(rows)
print(f"Wrote {len(rows)} rows to {OUT_CSV}")


if __name__ == "__main__":
main()