# -*- coding: utf-8 -*-
"""Run a Cellpose model on a single (edge-expanded) image and export ROIs.

The script applies CLAHE to the input image, runs a pretrained Cellpose model,
crops the expanded border off the resulting label mask (dropping labels that
lie mostly in the discarded border), and writes the centroids and ROIs next to
the input image.
"""
import os
import sys
import io as sysio

# The helper modules live in ./img_processing; the path tweak must run before
# they are imported.
sys.path.append(os.path.join(os.path.dirname(__file__), 'img_processing'))
# NOTE: the star imports provide renumber_masks() and centroids_from_mask().
# They are deliberately kept *before* the explicit imports below so that the
# explicit names (np, io, Path, ...) win on any name clash, as in the
# original import order.
from img_processing.pipeline_preprocessing import *
from img_processing.pipeline_templatematch import *
import numpy as np
import pandas as pd
from cellpose import io, models
from pathlib import Path
import cv2
import torch
import argparse


def _crop_and_filter_edge_masks(mask, crop_px, top, right, bottom, left):
    """Crop ``crop_px`` pixels off every edge and drop weakly-contained labels.

    A label that has pixels inside one of the four discarded border strips is
    removed from the cropped result when the fraction of its area that lies
    inside the cropped region (``cropped_area / original_area``) is below the
    threshold of *any* strip it touches.  Labels entirely outside the cropped
    region, or not touching any strip, are left untouched.

    Parameters:
    - mask: 2-D numpy array of integer labels, 0 meaning background.
    - crop_px: Non-negative number of pixels to remove from each edge.
    - top, right, bottom, left: Minimum inside-fraction required for a label
      touching the respective border strip to be kept.

    Returns:
    - A new array (the input is never modified) holding the cropped labels
      after being passed through ``renumber_masks``.

    Raises:
    - ValueError: if ``crop_px`` is negative or ``mask`` is not 2-D.
    """
    if crop_px < 0:
        raise ValueError(f"crop_px must be non-negative, got {crop_px}")

    height, width = mask.shape
    print(mask.shape)

    # Basic slicing returns a view; copy so that zeroing labels below does not
    # write through to the caller's array.
    cropped_mask = mask[crop_px:height - crop_px, crop_px:width - crop_px].copy()

    # Explicit start indices instead of ``-crop_px:`` so that crop_px == 0
    # yields empty strips rather than the whole array.
    row_start = max(height - crop_px, 0)
    col_start = max(width - crop_px, 0)
    edge_strips = (
        (mask[:crop_px, :], top),
        (mask[:, col_start:], right),
        (mask[row_start:, :], bottom),
        (mask[:, :crop_px], left),
    )

    # Only labels present in a border strip can ever be removed.  For each of
    # them keep the strictest (largest) threshold among the strips it touches:
    # "ratio < any threshold" is the same as "ratio < max threshold".
    strictest = {}
    for strip, threshold in edge_strips:
        for label in np.unique(strip):
            if label == 0:
                continue
            previous = strictest.get(label)
            if previous is None or threshold > previous:
                strictest[label] = threshold

    for label, threshold in strictest.items():
        inside = cropped_mask == label
        cropped_area = np.count_nonzero(inside)
        # Label lies completely in the discarded border: nothing to remove.
        if cropped_area == 0:
            continue
        original_area = np.count_nonzero(mask == label)
        if cropped_area / original_area < threshold:
            cropped_mask[inside] = 0

    # renumber_masks comes from the img_processing helpers (presumably it
    # relabels the remaining masks consecutively -- confirm in that module).
    return renumber_masks(cropped_mask)


def crop_expanded_mask(mask, crop_px, edge_threshold):
    """
    Crop an integer label mask by ``crop_px`` pixels on each edge and remove
    labels that straddle the crop boundary according to the following rules:
    1. If the label has pixels in the top or right border strip, it is set to
       0 when the fraction of its area inside the cropped region is less than
       ``edge_threshold``.
    2. If the label has pixels in the bottom or left border strip, it is set
       to 0 when that fraction is less than ``1 - edge_threshold``.

    Parameters:
    - mask: 2-D numpy array of integer masks (e.g., Cellpose masks).
    - crop_px: Number of pixels to crop from each edge.
    - edge_threshold: Fraction threshold used for the top/right strips; its
      complement is used for the bottom/left strips.

    Returns:
    - cropped_mask: New cropped mask array (the input is left unmodified)
      with the rejected labels set to 0, passed through ``renumber_masks``.
    """
    return _crop_and_filter_edge_masks(
        mask,
        crop_px,
        top=edge_threshold,
        right=edge_threshold,
        bottom=1 - edge_threshold,
        left=1 - edge_threshold,
    )


def crop_expanded_mask_updated(mask, crop_px, lt_edge_threshold, rb_edge_threshold):
    """
    Crop an integer label mask by ``crop_px`` pixels on each edge and remove
    labels that straddle the crop boundary according to the following rules:
    1. If the label has pixels in the top or left border strip, it is set to
       0 when the fraction of its area inside the cropped region is less than
       ``lt_edge_threshold``.
    2. If the label has pixels in the bottom or right border strip, it is set
       to 0 when that fraction is less than ``rb_edge_threshold``.

    Parameters:
    - mask: 2-D numpy array of integer masks (e.g., Cellpose masks).
    - crop_px: Number of pixels to crop from each edge.
    - lt_edge_threshold: Fraction threshold for labels touching the left or
      top border strip.
    - rb_edge_threshold: Fraction threshold for labels touching the right or
      bottom border strip.

    Returns:
    - cropped_mask: New cropped mask array (the input is left unmodified)
      with the rejected labels set to 0, passed through ``renumber_masks``.
    """
    return _crop_and_filter_edge_masks(
        mask,
        crop_px,
        top=lt_edge_threshold,
        right=rb_edge_threshold,
        bottom=rb_edge_threshold,
        left=lt_edge_threshold,
    )


def img_gray_clahe(image):
    """Convert an image to grayscale, apply CLAHE and return it as 3 channels.

    Parameters:
    - image: numpy image array; either 2-D, (H, W, C) or channel-first
      (3, H, W).

    Returns:
    - (H, W, 3) array in which all three channels hold the CLAHE-equalised
      grayscale image.
    """
    # Handle images separated to 3 channels (channel-first layout).  The rank
    # check keeps a 2-D image that happens to be 3 pixels tall from being
    # transposed with three axes.
    if image.ndim == 3 and image.shape[0] == 3:
        # transpose() returns a strided view; hand OpenCV a contiguous array.
        image = np.ascontiguousarray(image.transpose(1, 2, 0))

    # If image is grayscale, convert to 3-channel so the BGR->GRAY step below
    # applies uniformly.
    is_single_channel = image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1)
    if is_single_channel:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalised = clahe.apply(gray)

    # Replicate the single channel three times: (H, W) -> (H, W, 3).
    return np.repeat(equalised[:, :, np.newaxis], 3, axis=2)


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Run Cellpose on a set of images')
    # Arguments the script dereferences unconditionally are marked required so
    # a missing one yields a usage error instead of a TypeError later on.
    parser.add_argument('--input_image', dest='input_image', type=str, required=True,
                        help='Path to the image file')
    parser.add_argument('--model_path', dest='model_path', type=str,
                        help='Path to the model file')
    parser.add_argument('--flow_threshold', dest='flow_threshold', type=float,
                        help='Flow threshold')
    parser.add_argument('--cellprob_threshold', dest='cellprob_threshold', type=float,
                        help='Cell probability threshold')
    parser.add_argument('--expand_size', dest='expand_size', type=float, required=True,
                        help='Size of expanded edge in pixels')
    parser.add_argument('--lt_edge_threshold', dest='lt_edge_threshold', type=float,
                        required=True,
                        help='Left and top edge threshold for removing masks')
    parser.add_argument('--rb_edge_threshold', dest='rb_edge_threshold', type=float,
                        required=True,
                        help='Right and bottom edge threshold for removing masks')
    parser.add_argument('--gpu', dest='gpu', action='store_true', required=False,
                        help='Use GPU for processing if available')
    return parser.parse_args()


def _select_device(use_gpu):
    """Return ``(torch.device, gpu_flag)``, preferring MPS, then CUDA, then CPU."""
    if use_gpu:
        # Use GPU/MPS if available
        if torch.backends.mps.is_available():
            return torch.device('mps'), True
        if torch.cuda.is_available():
            return torch.device('cuda'), True
        print("GPU not available, using CPU instead.")
    return torch.device('cpu'), False


def main():
    """Segment the input image with Cellpose and write the ROI outputs.

    Writes ``<stem>_clahe.tif`` and ``<stem>_rois.csv`` next to the input
    image and calls ``cellpose.io.save_rois`` with the same stem.

    Raises:
    - FileNotFoundError: if OpenCV cannot read the input image.
    """
    args = _parse_args()
    device, gpu = _select_device(args.gpu)

    # Re-wrap stdout as UTF-8 so characters the console encoding cannot
    # represent are replaced instead of raising.
    sys.stdout = sysio.TextIOWrapper(
        sys.stdout.buffer,
        encoding='utf-8',
        errors='replace',  # or 'ignore'
        line_buffering=True
    )

    input_image = Path(args.input_image)
    # Same directory and name as the input, without the extension.
    output_roi_stem = input_image.with_name(input_image.stem)

    img = cv2.imread(args.input_image, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {args.input_image}")
    img = img_gray_clahe(img)
    cv2.imwrite(str(output_roi_stem) + "_clahe.tif", img)

    pretrained_model = models.CellposeModel(pretrained_model=args.model_path,
                                            gpu=gpu, device=device)

    expand_size = int(args.expand_size)

    # Run Cellpose
    masks, flows, styles = pretrained_model.eval(
        img,
        diameter=None,
        flow_threshold=args.flow_threshold,
        cellprob_threshold=args.cellprob_threshold,
        channels=[0, 0],
    )

    # Filter expanded masks
    masks = crop_expanded_mask_updated(masks, expand_size,
                                       lt_edge_threshold=args.lt_edge_threshold,
                                       rb_edge_threshold=args.rb_edge_threshold)

    # Output point ROIs as CSV
    x, y = centroids_from_mask(masks)
    df = pd.DataFrame({'X': x, 'Y': y})
    df.to_csv(str(output_roi_stem) + "_rois.csv", index=False)

    io.save_rois(masks, output_roi_stem)


if __name__ == '__main__':
    main()