"""CPython entry point for the p63 sidecar. Run inside the cellpose4_p63 conda env.

Pipeline: list images -> load the Cellpose model once -> for each chunk of
--batch-size images: load channels, segment, write masks, extract per-cell
red means -> pool cells per group -> compute thresholds per group -> write
cells.csv and summary.csv (rewritten from scratch on every run).
"""

import argparse
import csv
import os
import re
import sys
import traceback

import numpy as np
import tifffile

from thresholding import (
    extract_red,
    get_cell_data,
    compute_thresholds,
    count_p63_positive,
)

VARIANTS = ("T1", "T2", "T7", "T14")

_IMAGE_EXTS = {".tif", ".tiff", ".png", ".jpg", ".jpeg"}
# tifffile can only decode TIFF containers; everything else in _IMAGE_EXTS
# needs a different reader (see _read_image).
_TIFF_EXTS = {".tif", ".tiff"}

_CELLS_CSV = "cells.csv"
_SUMMARY_CSV = "summary.csv"

# Matches e.g. '_R1' / '_r12' inside a basename such as 'D1_R1_01'.
_REPLICATE_RE = re.compile(r"_R(\d+)", re.IGNORECASE)

_CELLS_HEADER = [
    "image",
    "cell_label",
    "red_mean",
    "t1_positive",
    "t2_positive",
    "t7_positive",
    "t14_positive",
    "variant_call",
]

_SUMMARY_HEADER = [
    "image",
    "n_cells",
    "t1_threshold",
    "t2_threshold",
    "t7_threshold",
    "t14_threshold",
    "t1_p63_positive",
    "t2_p63_positive",
    "t7_p63_positive",
    "t14_p63_positive",
    "t1_p63_fraction",
    "t2_p63_fraction",
    "t7_p63_fraction",
    "t14_p63_fraction",
]


def parse_args(argv=None):
    """Define and parse CLI arguments; return an argparse.Namespace.

    Args:
        argv: Optional argument list (defaults to sys.argv[1:]); handy for tests.

    Channel indices are 1-based. A value < 1 is rejected via ``parser.error``
    (exit code 2) instead of silently wrapping to the last channel.
    """
    p = argparse.ArgumentParser(
        description="Run Cellpose segmentation + p63 quantification on a folder of images.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--input-dir", required=True,
                   help="Folder of original multi-channel images")
    p.add_argument("--output-dir", required=True,
                   help="Where masks and CSVs are written")
    p.add_argument("--model", default="cpsam",
                   help="'cpsam' (default) or path to a custom Cellpose model")
    p.add_argument("--seg-channel", type=int, default=3,
                   help="1-based channel index for segmentation (blue/DAPI)")
    p.add_argument("--quant-channel", type=int, default=1,
                   help="1-based channel index for quantification (red/p63)")
    p.add_argument("--variant", default="T14", choices=list(VARIANTS),
                   help="Threshold variant that defines the primary p63+ call")
    p.add_argument("--atten-cap", type=float, default=0.7,
                   help="Attenuation cap for T2/T14 variants")
    p.add_argument("--use-gpu", action="store_true",
                   help="Run Cellpose on GPU")
    p.add_argument("--pooling-mode", default="per_image",
                   choices=["per_image", "pool_all", "pool_replicate", "reference_replicate"],
                   help="How to pool images when computing Otsu thresholds")
    p.add_argument("--reference-replicate", default="",
                   help="Replicate whose cells seed the threshold applied to all images "
                        "(e.g. 'R1'); only used when --pooling-mode=reference_replicate")
    p.add_argument("--batch-size", type=int, default=8,
                   help="Images per Cellpose eval call; reduce if you run out of memory")
    args = p.parse_args(argv)

    # A 0 would become index -1 after the 1-based -> 0-based conversion and
    # silently select the last channel.
    for flag, value in (("--seg-channel", args.seg_channel),
                        ("--quant-channel", args.quant_channel)):
        if value < 1:
            p.error("{} must be >= 1 (channels are 1-based), got {}".format(flag, value))
    return args


def load_model(model_choice, use_gpu):
    """Load and return a CellposeModel (loaded once, reused for every image).

    Cellpose 4.0.1+ removed model_type; all named models (cpsam, cyto3, cyto2,
    nuclei) and custom file paths are specified via pretrained_model. On the
    first use of a named model Cellpose downloads the weights automatically.
    """
    from cellpose import models
    return models.CellposeModel(gpu=use_gpu, pretrained_model=model_choice)


def _extract_replicate(basename):
    """Return e.g. 'R1' from 'D1_R1_01', or None if no R-number found."""
    m = _REPLICATE_RE.search(basename)
    return ("R" + m.group(1)) if m else None


def _normalize_replicate_arg(ref_arg):
    """Normalize --reference-replicate to 'R' form. Accepts 'R1', 'r1', or bare '1'.

    Any other string is returned stripped but otherwise unchanged.
    """
    s = ref_arg.strip()
    if re.match(r'^\d+$', s):
        return "R" + s
    if re.match(r'^[Rr]\d+$', s):
        return "R" + s[1:]
    return s


def _group_key(basename, pooling_mode):
    """Return the pooling group key for basename.

    per_image           -> the basename itself
    pool_all            -> '__all__'
    reference_replicate -> '__all__'; main() controls which images actually
                           contribute cells to that bucket.
    pool_replicate      -> the replicate ('R1', ...), or the basename (with a
                           warning) when the name carries no R-number.
    """
    if pooling_mode == "per_image":
        return basename
    if pooling_mode in ("pool_all", "reference_replicate"):
        return "__all__"
    # pool_replicate
    rep = _extract_replicate(basename)
    if rep:
        return rep
    print("[WARN] {} has no R-number in name; using per-image key".format(basename),
          file=sys.stderr)
    return basename


def _read_image(image_path):
    """Read one image file into a NumPy array.

    TIFF files go through tifffile. tifffile cannot decode PNG/JPEG, so those
    use Cellpose's reader, imported locally as load_model does.

    NOTE(review): cellpose.io.imread is expected to return colour images in
    RGB order, so 1-based channel 1 = red and channel 3 = blue. Confirm this
    against the installed Cellpose version.

    Raises:
        IOError: if the non-TIFF reader cannot decode the file.
    """
    ext = os.path.splitext(image_path)[1].lower()
    if ext in _TIFF_EXTS:
        return tifffile.imread(image_path)
    from cellpose import io as cp_io
    img = cp_io.imread(image_path)
    if img is None:  # cellpose's reader reports failures by returning None
        raise IOError("could not decode image: {}".format(image_path))
    return img


def _load_image_channels(image_path, seg_channel, quant_channel):
    """Load one image file; return (basename, seg_ch_float32, red_ch).

    seg_channel and quant_channel are 1-based (Fiji convention).
    Does not run Cellpose -- called before the batched eval pass.
    """
    basename = os.path.splitext(os.path.basename(image_path))[0]
    img = _read_image(image_path)
    seg_ch = extract_red(img, seg_channel - 1).astype(np.float32)
    red_ch = extract_red(img, quant_channel - 1)
    return basename, seg_ch, red_ch


def _cells_rows(result):
    """Build cells.csv rows from a finalized result dict (thresholds already set).

    Cell means and thresholds are compared on the same scale. Only the
    reported red_mean is rescaled by 255.
    """
    th = result["thresholds"]
    variant = result["variant"]
    rows = []
    for c in result["cells"]:
        mean = c["mean"]
        rows.append({
            "image": result["image"],
            "cell_label": c["label"],
            "red_mean": round(mean * 255, 4),  # report on 0-255 scale
            "t1_positive": int(mean >= th["T1"]),
            "t2_positive": int(mean >= th["T2"]),
            "t7_positive": int(mean >= th["T7"]),
            "t14_positive": int(mean >= th["T14"]),
            "variant_call": int(mean >= th[variant]),
        })
    return rows


def _summary_row(result):
    """Build a single summary.csv row from a finalized result dict.

    Fractions are NaN for images with zero cells.
    """
    th = result["thresholds"]
    n_pos = result["n_pos"]
    n_cells = result["n_cells"]
    row = {"image": result["image"], "n_cells": n_cells}
    for v in VARIANTS:
        vl = v.lower()
        frac = round(n_pos[v] / n_cells, 6) if n_cells > 0 else float("nan")
        row["{}_threshold".format(vl)] = round(th[v] * 255, 4)
        row["{}_p63_positive".format(vl)] = n_pos[v]
        row["{}_p63_fraction".format(vl)] = frac
    return row


def _append_csv(path, header, rows):
    """Write rows to a CSV, writing the header only when the file does not yet exist."""
    write_header = not os.path.exists(path)
    with open(path, "a", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=header)
        if write_header:
            writer.writeheader()
        writer.writerows(rows)


def write_cells_csv(output_dir, rows):
    """Append rows to cells.csv (image, cell_label, red_mean, t1_positive...t14_positive, variant_call)."""
    _append_csv(os.path.join(output_dir, _CELLS_CSV), _CELLS_HEADER, rows)


def write_summary_csv(output_dir, rows):
    """Append rows to summary.csv (image, T1..T14 thresholds, p63+ counts/fractions, n_cells)."""
    _append_csv(os.path.join(output_dir, _SUMMARY_CSV), _SUMMARY_HEADER, rows)


def _reset_csv_outputs(output_dir):
    """Delete cells.csv / summary.csv left over from a previous run.

    The write_* helpers append. Without this, re-running into the same output
    dir would duplicate every row, while the mask files are simply overwritten.
    """
    for name in (_CELLS_CSV, _SUMMARY_CSV):
        try:
            os.remove(os.path.join(output_dir, name))
        except FileNotFoundError:
            pass


def _list_images(input_dir):
    """Return sorted names of regular files in input_dir with an accepted image extension."""
    return sorted(
        f for f in os.listdir(input_dir)
        if os.path.splitext(f)[1].lower() in _IMAGE_EXTS
        and os.path.isfile(os.path.join(input_dir, f))
    )


def _load_chunk(input_dir, names, offset, total, seg_channel, quant_channel):
    """Load the channels of one chunk of files.

    Returns (loaded, n_failed). loaded is a list of
    {basename, seg_ch, red_ch} dicts. Failures are logged with a traceback
    and counted, not raised.
    """
    loaded = []
    n_failed = 0
    for i, fname in enumerate(names, offset + 1):
        print("[{}/{}] loading {}".format(i, total, fname))
        try:
            basename, seg_ch, red_ch = _load_image_channels(
                os.path.join(input_dir, fname), seg_channel, quant_channel,
            )
        except Exception as exc:
            print("[ERROR] {}: {}".format(fname, exc), file=sys.stderr)
            traceback.print_exc()
            n_failed += 1
            continue
        loaded.append({"basename": basename, "seg_ch": seg_ch, "red_ch": red_ch})
    return loaded, n_failed


def _segment_chunk(model, chunk, first, last):
    """Run one Cellpose eval call on a chunk; return one mask array per item.

    Exits the process (status 1) if eval raises, or if the number of returned
    masks does not match the chunk. A mismatch would otherwise be silently
    truncated by zip().
    """
    try:
        seg_inputs = [item["seg_ch"] for item in chunk]
        masks_list, _, _ = model.eval(seg_inputs, diameter=None, channels=None)
    except Exception as exc:
        print("[ERROR] Cellpose eval failed on chunk {}-{}: {}".format(
            first, last, exc), file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)
    if len(chunk) == 1 and not isinstance(masks_list, list):
        masks_list = [masks_list]
    if len(masks_list) != len(chunk):
        print("[ERROR] Cellpose returned {} mask(s) for {} image(s) in chunk {}-{}".format(
            len(masks_list), len(chunk), first, last), file=sys.stderr)
        sys.exit(1)
    return masks_list


def _postprocess(item, masks, output_dir, pooling_mode):
    """Write the mask TIFF and extract per-cell data for one segmented image.

    Returns a raw result dict {image, cells, n_cells, group_key}, or None on
    failure (already logged).
    """
    basename = item["basename"]
    mask_path = os.path.join(output_dir, "{}_cp_masks.tif".format(basename))
    try:
        # uint16 wraps labels above 65535; widen instead of corrupting the mask.
        out_dtype = (np.uint16 if masks.max(initial=0) <= np.iinfo(np.uint16).max
                     else np.uint32)
        tifffile.imwrite(mask_path, masks.astype(out_dtype))
        cells = get_cell_data(masks, item["red_ch"])
        if not cells:
            print("[WARN] No cells found in {}".format(basename), file=sys.stderr)
    except Exception as exc:
        print("[ERROR] post-segmentation for {}: {}".format(basename, exc), file=sys.stderr)
        traceback.print_exc()
        return None
    return {
        "image": basename,
        "cells": cells,
        "n_cells": len(cells),
        "group_key": _group_key(basename, pooling_mode),
    }


def _compute_group_thresholds(group_cells, atten_cap):
    """Compute thresholds for every pooled group.

    Empty groups get all-zero thresholds, with a warning.
    """
    group_thresholds = {}
    for key, pooled in group_cells.items():
        if not pooled:
            group_thresholds[key] = {v: 0.0 for v in VARIANTS}
            print("[WARN] Group '{}' has no cells; thresholds set to 0".format(key),
                  file=sys.stderr)
            continue
        th = compute_thresholds(pooled, attenuation_cap=atten_cap)
        group_thresholds[key] = th
        print("Group '{}': {} cells T1={:.1f} T2={:.1f} T7={:.1f} T14={:.1f}".format(
            key, len(pooled),
            th["T1"] * 255, th["T2"] * 255, th["T7"] * 255, th["T14"] * 255,
        ))
    return group_thresholds


def main():
    """Parse args, load model once, then run: load -> segment -> pool -> threshold -> write.

    Images are loaded one --batch-size chunk at a time, so only that chunk's
    pixel data is held in memory. Per-cell data from all images is kept,
    because thresholds depend on pooled cells.

    Exits with status 1 on fatal errors, or when any image failed.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    image_files = _list_images(args.input_dir)
    if not image_files:
        print("[ERROR] No images found in {}".format(args.input_dir), file=sys.stderr)
        sys.exit(1)

    ref_rep = ""
    if args.pooling_mode == "reference_replicate":
        if not args.reference_replicate.strip():
            print("[ERROR] --reference-replicate must be set when using "
                  "reference_replicate mode", file=sys.stderr)
            sys.exit(1)
        ref_rep = _normalize_replicate_arg(args.reference_replicate)
        print("Reference replicate for thresholding: {}".format(ref_rep))

    print("Loading model '{}' (gpu={})...".format(args.model, args.use_gpu))
    model = load_model(args.model, args.use_gpu)

    # ------------------------------------------------------------------
    # Load + segment in chunks, write masks, extract cells, pool groups
    # ------------------------------------------------------------------
    n = len(image_files)
    chunk_size = max(1, args.batch_size)
    print("Segmenting {} image(s) in chunks of {} [pooling_mode={}]...".format(
        n, chunk_size, args.pooling_mode))

    raw_results = []
    group_cells = {}
    n_loaded = 0
    n_err = 0
    for chunk_start in range(0, n, chunk_size):
        names = image_files[chunk_start:chunk_start + chunk_size]
        chunk_end = chunk_start + len(names)
        chunk, n_failed = _load_chunk(
            args.input_dir, names, chunk_start, n, args.seg_channel, args.quant_channel,
        )
        n_err += n_failed
        n_loaded += len(chunk)
        if not chunk:
            continue  # nothing loadable in this chunk; don't call eval on []

        print("  segmenting {}-{} of {}...".format(chunk_start + 1, chunk_end, n))
        masks_list = _segment_chunk(model, chunk, chunk_start + 1, chunk_end)

        for item, masks in zip(chunk, masks_list):
            raw = _postprocess(item, masks, args.output_dir, args.pooling_mode)
            if raw is None:
                n_err += 1
                continue
            raw_results.append(raw)
            if args.pooling_mode == "reference_replicate":
                # Only reference-replicate images seed the shared '__all__' threshold.
                img_rep = _extract_replicate(raw["image"])
                if img_rep and img_rep.upper() == ref_rep.upper():
                    group_cells.setdefault("__all__", []).extend(raw["cells"])
            else:
                group_cells.setdefault(raw["group_key"], []).extend(raw["cells"])

    if n_loaded == 0:
        print("[ERROR] All images failed to load.", file=sys.stderr)
        sys.exit(1)

    if args.pooling_mode == "reference_replicate":
        ref_cells = group_cells.get("__all__", [])
        print("Reference replicate '{}' contributed {} cells.".format(ref_rep, len(ref_cells)))
        if not ref_cells:
            print("[ERROR] No cells found for reference replicate '{}'. "
                  "Check filenames and --reference-replicate value.".format(ref_rep),
                  file=sys.stderr)
            sys.exit(1)

    # ------------------------------------------------------------------
    # Compute thresholds per group from pooled cells
    # ------------------------------------------------------------------
    group_thresholds = _compute_group_thresholds(group_cells, args.atten_cap)

    # ------------------------------------------------------------------
    # Apply group thresholds per image and write fresh CSVs
    # ------------------------------------------------------------------
    cells_rows = []
    summary_rows = []
    for raw in raw_results:
        thresholds = group_thresholds.get(raw["group_key"], {v: 0.0 for v in VARIANTS})
        cells = raw["cells"]
        n_pos = {v: count_p63_positive(cells, thresholds[v]) for v in VARIANTS}
        result = {
            "image": raw["image"],
            "cells": cells,
            "thresholds": thresholds,
            "n_pos": n_pos,
            "n_cells": raw["n_cells"],
            "variant": args.variant,
        }
        cells_rows.extend(_cells_rows(result))
        summary_rows.append(_summary_row(result))

    _reset_csv_outputs(args.output_dir)
    write_cells_csv(args.output_dir, cells_rows)
    write_summary_csv(args.output_dir, summary_rows)

    n_ok = len(raw_results)
    print("Done. OK={} ERR={} Output: {}".format(n_ok, n_err, args.output_dir))
    if n_err:
        sys.exit(1)


if __name__ == "__main__":
    main()