from ij import IJ, WindowManager
from ij.plugin.frame import RoiManager
from ij.gui import EllipseRoi, GenericDialog
from ij.plugin import ImageCalculator
from ij.process import ImageProcessor
from ij.measure import Measurements, ResultsTable
from ij.io import SaveDialog


def measure_all(im_plus, frame, measures=None):
    """Return the mean intensity of every ROI Manager ROI on one slice.

    Args:
        im_plus: ImagePlus to measure. It is made the temporary current
            image and switched to ``frame``. The last ROI measured is left
            set on it.
        frame: slice index passed to ``ImagePlus.setSlice`` (the caller
            builds these 1-based).
        measures: ``Measurements`` flags passed to ``getStatistics``.
            Defaults to ``Measurements.MEAN``. Only ``.mean`` is read from
            the result.

    Returns:
        List of mean values, one per ROI in ROI Manager order. The list is
        empty when no ROI Manager instance exists.
    """
    # Resolve the default at call time so the function no longer depends on
    # a module-level global that is defined after it.
    if measures is None:
        measures = Measurements.MEAN
    WindowManager.setTempCurrentImage(im_plus)
    im_plus.setSlice(frame)
    manager = RoiManager.getInstance()
    if manager is None:
        # No ROI Manager open: there is nothing to measure.
        return []
    means = []
    for roi in manager.getRoisAsArray():
        im_plus.setRoi(roi)
        means.append(im_plus.getStatistics(measures).mean)
    return means


# Only the mean intensity is read from each ROI measurement.
measures = Measurements.MEAN

# All open image titles in alphabetical order, followed by a 'None' sentinel
# that lets the user leave an "Other Image" slot empty.
imageList = sorted(WindowManager.getImageTitles()) + ['None']

gd = GenericDialog('Set options')
gd.addChoice('Mask Image', imageList, imageList[0])
# One extra selector per remaining open image, each pre-selected to a
# different title. The sentinel itself gets no selector.
for other_title in imageList[1:-1]:
    gd.addChoice('Other Image', imageList, other_title)
# The main block uses this box to choose between processing every slice of
# the mask image and only its current slice. It also controls whether
# results are saved to a file.
gd.addCheckbox('Do all images?', False)
# Options are read back as text after the dialog closes. Their order here
# fixes their index in gd.getStringFields().
for label, default in (('Find Maxima noise threshold', '600'),
                       ('Nucleus size (radius in pixels)', '3'),
                       ('Override auto thresholding? (0 for auto)', '0'),
                       ('Particle size filter (pixels^2)', '20-1600')):
    gd.addStringField(label, default)
gd.centerDialog(True)
gd.showDialog()

if gd.wasOKed():

    # Selected titles in dialog order: the mask image first, then the
    # "Other Image" choices. 'None' selections are dropped.
    image_names = [choice.getSelectedItem() for choice in gd.getChoices()]
    image_names = [title for title in image_names if title != 'None']
    print(image_names)
    ips = [WindowManager.getImage(title) for title in image_names]
    mask_image = ips[0]
    out = ResultsTable()
    RoiManager()  # ensure a ROI Manager instance exists for measure_all

    # Read the checkbox once. It selects all slices vs. the current slice,
    # and also whether the results are saved at the end.
    do_all = gd.getCheckboxes()[0].getState()
    if do_all:
        frames = range(1, mask_image.getNSlices() + 1)
    else:
        frames = [mask_image.getCurrentSlice()]

    # String fields in dialog order: noise threshold, nucleus radius,
    # manual threshold ('0' = auto), particle size range.
    strings = [field.getText() for field in gd.getStringFields()]
    nuc_size = int(strings[1])

    for frame in frames:
        mask_image.setSlice(frame)
        mask_image.deleteRoi()
        WindowManager.setCurrentWindow(WindowManager.getWindow(image_names[0]))
        roiM = RoiManager.getInstance()
        roiM.reset()
        IJ.run("Duplicate...", "title=thresh")
        thresh = IJ.getImage()
        IJ.run("Colors...",
               "foreground=white background=black selection=yellow")

        IJ.run("Smooth")
        IJ.run("Duplicate...", "title=mask")
        mask = IJ.getImage()
        IJ.run("Find Maxima...", "noise=" +
               strings[0] + " output=[Point Selection] exclude")

        # Find Maxima leaves a point selection on `mask`. If it found
        # nothing, the else branch records 0 for every image in this frame.
        if mask.getRoi():

            points = list(mask.getRoi().getContainedPoints())

            # Clear the mask, then paint one filled disc per maximum
            # directly through the image processor.
            IJ.run("8-bit")
            IJ.run("Select All")
            IJ.run("Clear", "slice")
            mask_proc = mask.getProcessor()

            for i, point in enumerate(points):
                x, y = point.x, point.y
                # EllipseRoi takes the END POINTS OF THE MAJOR AXIS. A
                # horizontal axis of length 2*nuc_size with aspect ratio 1
                # gives a disc of radius nuc_size. The previous diagonal
                # axis (x-r, y-r)->(x+r, y+r) produced a radius of r*sqrt(2).
                # NOTE: masks are smaller than before; re-check tuned values.
                circle = EllipseRoi(x - nuc_size, y,
                                    x + nuc_size, y, 1)
                mask_proc.fill(circle)
                if i % 64 == 0:
                    mask.updateAndDraw()  # periodic redraw while filling

            mask.updateAndDraw()

            IJ.run("Invert LUT")
            IJ.run("Make Binary")
            WindowManager.setCurrentWindow(WindowManager.getWindow("thresh"))
            if strings[2] == '0':
                IJ.setAutoThreshold(thresh, "Li dark")
            else:
                thresh.getProcessor().setThreshold(
                    int(strings[2]), 65535, ImageProcessor.NO_LUT_UPDATE)
            IJ.run("Make Binary")
            # Keep only thresholded pixels that also lie inside a disc.
            ic = ImageCalculator()
            ic.run("AND thresh mask", thresh, mask)

            IJ.run("Analyze Particles...",
                   "size=" + strings[3] + " display exclude clear summarize add")

            # One column per (image, frame). Rows are ROI Manager indices.
            for title, imp in zip(image_names, ips):
                means = measure_all(imp, frame)
                while out.getCounter() < len(means):
                    out.incrementCounter()
                for row_index, mean in enumerate(means):
                    out.setValue(title + ' #' + str(frame), row_index, mean)

        else:
            # No maxima: record a single 0 per image. Make sure row 0 exists
            # first, as the branch above does for its rows.
            if out.getCounter() < 1:
                out.incrementCounter()
            for title in image_names:
                out.setValue(title + ' #' + str(frame), 0, 0)

        thresh.changes = False
        mask.changes = False
        thresh.close()
        mask.close()

        out.show("Output")

    if do_all:

        sd = SaveDialog("Save Results", "nuclei", ".txt")
        directory, base_name = sd.getDirectory(), sd.getFileName()
        # getFileName() is None when the user cancels; skip saving then.
        if base_name is not None:
            file_name = directory + base_name
            out_file = open(file_name, 'w')
            try:
                print(file_name)
                out_file.write('\t'.join(out.getHeadings()) + '\n')
                for i in range(out.getCounter()):
                    cells = out.getRowAsString(i).split('\t')
                    for j, cell in enumerate(cells):
                        # Zeros are written as empty cells. Presumably they
                        # are padding for frames with fewer ROIs; note that
                        # genuine zero means and the no-maxima placeholder
                        # are blanked too.
                        try:
                            if float(cell) == 0.0:
                                cells[j] = ''
                        except ValueError:
                            pass  # non-numeric cell: write it unchanged
                    out_file.write('\t'.join(cells) + '\n')
            finally:
                out_file.close()
