"""Batch navigator: JFrame panel for reviewing p63+ results across all images in a run."""

import os

from javax.swing import (
    JFrame, JTable, JScrollPane, JPanel, JButton, JLabel, JSlider,
    ButtonGroup, JToggleButton, BorderFactory, SwingUtilities,
    WindowConstants, ListSelectionModel,
)
from javax.swing.table import DefaultTableModel
from javax.swing.event import ListSelectionListener, ChangeListener
from java.awt import (
    BorderLayout, FlowLayout, GridBagLayout, GridBagConstraints,
    Dimension, Insets,
)
from java.awt.event import ActionListener, WindowAdapter, MouseListener
from java.util import Hashtable
from java.util.concurrent import CountDownLatch
from java.lang import Runnable

from ij import IJ, CompositeImage
from ij.process import ImageConverter

import display
import review


class _TableModel(DefaultTableModel):
    """Display-only table model: isCellEditable() is False for every cell."""

    def isCellEditable(self, row, col):
        # Values are written programmatically by the navigator, never by
        # in-place editing, so no (row, col) pair is editable.
        editable = False
        return editable


class _CellTable(object):
    """Floating JFrame showing per-cell results for the currently-selected image.

    Row i of the table always corresponds to entry i of the (roi, mean) list
    passed to update_all(). A row picked by the user is forwarded to the
    on_row_selected callback; programmatic updates set the _updating flag so
    they do not echo back through that callback.
    """

    # Column index of the "p63+" column.
    _COL_P63 = 2

    def __init__(self, on_row_selected):
        """on_row_selected -- callable(row_index) invoked when the user selects a row."""
        self._on_row_selected = on_row_selected
        self._rois    = []   # list of (roi, mean_255) for current image
        self._updating = False

        self._model  = _TableModel(["Cell", "Red Mean", "p63+"], 0)
        self._jtable = JTable(self._model)
        self._jtable.setSelectionMode(ListSelectionModel.SINGLE_SELECTION)
        self._jtable.setRowHeight(18)
        self._jtable.setAutoResizeMode(JTable.AUTO_RESIZE_LAST_COLUMN)
        self._jtable.getColumnModel().getColumn(0).setPreferredWidth(55)
        self._jtable.getColumnModel().getColumn(1).setPreferredWidth(80)
        self._jtable.getColumnModel().getColumn(2).setPreferredWidth(45)

        ct = self
        class _SelListener(ListSelectionListener):
            def valueChanged(self, e):
                # Ignore intermediate drag events and our own programmatic changes.
                if e.getValueIsAdjusting() or ct._updating:
                    return
                row = ct._jtable.getSelectedRow()
                if row >= 0:
                    ct._on_row_selected(row)

        self._jtable.getSelectionModel().addListSelectionListener(_SelListener())

        scroll = JScrollPane(self._jtable)
        scroll.setPreferredSize(Dimension(220, 420))

        self._frame = JFrame("p63 Cell Results")
        self._frame.setDefaultCloseOperation(WindowConstants.HIDE_ON_CLOSE)
        self._frame.setContentPane(scroll)
        self._frame.pack()

    @staticmethod
    def _p63_sign(mean, threshold):
        """Return "+" when mean reaches threshold (inclusive), otherwise "-"."""
        return "+" if mean >= threshold else "-"

    def show(self):
        """Make the frame visible and raise it above other windows."""
        self._frame.setVisible(True)
        self._frame.toFront()

    def update_all(self, rois, labels, threshold):
        """Rebuild table rows for a new image.

        rois      -- list of (roi, mean) pairs; row i describes rois[i]
        labels    -- cell labels parallel to rois; missing entries fall back
                     to the 1-based row number
        threshold -- mean value at or above which a cell is shown as "+"
        """
        rois = rois if rois is not None else []
        labels = labels if labels is not None else []
        self._rois = rois
        self._updating = True
        try:
            self._model.setRowCount(0)
            for i, (_roi, mean) in enumerate(rois):
                lbl = labels[i] if i < len(labels) else (i + 1)
                self._model.addRow([
                    str(lbl),
                    "{:.1f}".format(mean),
                    self._p63_sign(mean, threshold),
                ])
            self._jtable.clearSelection()
        finally:
            self._updating = False

    def update_p63_col(self, threshold):
        """Update the p63+ column in-place without disturbing the row selection."""
        n_rows = self._model.getRowCount()
        self._updating = True
        try:
            for i, (_roi, mean) in enumerate(self._rois):
                if i >= n_rows:
                    # Never index past the rows actually present in the model.
                    break
                self._model.setValueAt(self._p63_sign(mean, threshold), i, self._COL_P63)
        finally:
            self._updating = False

    def select_row(self, idx):
        """Programmatically highlight row idx and scroll to it.

        Out-of-range indices are ignored: JTable.setRowSelectionInterval
        would otherwise throw IllegalArgumentException.
        """
        if not 0 <= idx < self._model.getRowCount():
            return
        self._updating = True
        try:
            self._jtable.setRowSelectionInterval(idx, idx)
            self._jtable.scrollRectToVisible(
                self._jtable.getCellRect(idx, 0, True)
            )
        finally:
            self._updating = False


class _Navigator(object):
    """Swing JFrame that lists all images and provides per-image threshold review.

    Each row of the main table is one image from summary.csv. Selecting a row
    opens that image, colours the ROIs derived from ``<name>_cp_masks.tif``
    against the image's stored threshold, and fills the floating per-cell
    table. Moving the slider is a live preview only; a threshold is stored for
    an image when "Accept for this image" is pressed. "Accept All & Close"
    (also triggered by closing the window) writes every not-yet-accepted image
    with its stored threshold and releases the latch that show_navigator()
    waits on.
    """

    # Threshold used when summary.csv has no usable value for an image.
    _DEFAULT_THRESHOLD = 128.0
    # A final threshold further than this from the seeded (Otsu) value is
    # recorded as a manual override.
    _MANUAL_TOLERANCE = 0.5
    # The slider stores threshold * 100 so it can carry two decimal places.
    _SLIDER_SCALE = 100.0
    # Column indices in the main image table.
    _COL_THRESHOLD = 3
    _COL_STATUS = 4

    def __init__(self, summary_rows, cells_by_image, image_map, output_dir,
                 variant_str, threshold_col, seg_channel, quant_channel,
                 results_table, latch):
        """See show_navigator() for the meaning of the arguments.

        latch -- CountDownLatch counted down once the navigator closes.
        """
        self._summary_rows    = summary_rows
        self._cells_by_image  = cells_by_image
        self._image_map       = image_map
        self._output_dir      = output_dir
        self._variant_str     = variant_str
        self._seg_channel     = seg_channel
        self._quant_channel   = quant_channel
        self._results_table   = results_table
        self._latch           = latch

        self._image_names = [r["image"] for r in summary_rows]

        # Seed per-image thresholds from summary.csv; a missing, empty or
        # unparsable value falls back to the default.
        self._otsu_thresholds    = {}
        self._current_thresholds = {}
        for row in summary_rows:
            name = row["image"]
            try:
                t = float(row[threshold_col])
            except (KeyError, ValueError, TypeError):
                t = self._DEFAULT_THRESHOLD
            self._otsu_thresholds[name]    = t
            self._current_thresholds[name] = t

        # Runtime state
        self._current_image    = None
        self._current_imp      = None
        self._channel_mode     = "merge"
        self._roi_cache        = {}     # name -> [(roi, red_mean), ...]
        self._label_cache      = {}     # name -> [cell_label, ...] parallel to _roi_cache
        self._accepted_images  = set()
        self._updating_slider  = False  # suppresses the slider listener
        self._cell_table       = None   # created in _build_ui

        self._build_ui()

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_mean_by_label(cells):
        """Map int cell label -> float red mean from cells.csv rows.

        Labels written as floats (e.g. "4.0") are accepted. Rows with a
        missing or unparsable label/mean are skipped, so those cells fall
        back to a mean of 0.0 instead of aborting the image load.
        """
        mean_by_label = {}
        for cell in cells:
            try:
                label = int(float(cell["cell_label"]))
                mean = float(cell["red_mean"])
            except (KeyError, ValueError, TypeError, OverflowError):
                continue
            mean_by_label[label] = mean
        return mean_by_label

    @staticmethod
    def _format_counts(n_pos, n_neg, t):
        """Text for the count label: positive/negative cell counts and threshold."""
        return "p63+: {}  p63-: {}  (T={:.2f})".format(n_pos, n_neg, t)

    # ------------------------------------------------------------------
    # UI construction
    # ------------------------------------------------------------------

    def _build_ui(self):
        """Create the image table, controls, listeners, main frame and cell table."""
        # ---- Table ----
        col_names = ["Image", "Cells", "p63+ %", "Threshold", "Status"]
        self._table_model = _TableModel(col_names, 0)
        vl = self._variant_str.lower()
        for row in self._summary_rows:
            name    = row["image"]
            n_cells = row.get("n_cells", "?")
            try:
                frac    = float(row.get(vl + "_p63_fraction", 0.0)) * 100.0
                frac_s  = "{:.1f}%".format(frac)
            except (ValueError, TypeError):
                frac_s  = "?"
            t = self._otsu_thresholds.get(name, self._DEFAULT_THRESHOLD)
            self._table_model.addRow([name, n_cells, frac_s, "{:.1f}".format(t), "Pending"])

        self._table = JTable(self._table_model)
        self._table.setSelectionMode(ListSelectionModel.SINGLE_SELECTION)
        self._table.setRowHeight(20)
        self._table.setAutoResizeMode(JTable.AUTO_RESIZE_LAST_COLUMN)
        self._table.getColumnModel().getColumn(0).setPreferredWidth(220)
        self._table.getColumnModel().getColumn(1).setPreferredWidth(50)
        self._table.getColumnModel().getColumn(2).setPreferredWidth(70)
        self._table.getColumnModel().getColumn(3).setPreferredWidth(80)
        self._table.getColumnModel().getColumn(4).setPreferredWidth(80)

        n_visible = min(len(self._image_names), 15)
        scroll = JScrollPane(self._table)
        scroll.setPreferredSize(Dimension(560, max(80, self._table.getRowHeight() * n_visible + 30)))

        # ---- Channel toggle buttons ----
        self._btn_red   = JToggleButton("Red (p63)")
        self._btn_blue  = JToggleButton("Blue (DAPI)")
        self._btn_merge = JToggleButton("Merge")
        self._btn_merge.setSelected(True)
        grp = ButtonGroup()
        grp.add(self._btn_red)
        grp.add(self._btn_blue)
        grp.add(self._btn_merge)

        # ---- Threshold slider ----
        # Internal range is *100 (0-25500) so the value represents a float
        # 0.00-255.00 with 0.01 precision. Tick labels are remapped to show
        # the real 0-255 scale.
        self._slider = JSlider(0, 25500, 12800)
        self._slider.setMajorTickSpacing(6400)
        self._slider.setMinorTickSpacing(1600)
        self._slider.setPaintTicks(True)
        self._slider.setPaintLabels(True)
        self._slider.setPreferredSize(Dimension(300, 50))
        self._slider.setEnabled(False)
        _tick_labels = Hashtable()
        for _v, _lbl in [(0, "0"), (6400, "64"), (12800, "128"),
                         (19200, "192"), (25500, "255")]:
            _tick_labels.put(_v, JLabel(_lbl))
        self._slider.setLabelTable(_tick_labels)

        self._count_label = JLabel("Select an image above to begin")

        self._btn_reset      = JButton("Reset to Otsu")
        self._btn_reset.setEnabled(False)
        self._btn_accept     = JButton("Accept for this image")
        self._btn_accept.setEnabled(False)
        self._btn_accept_all = JButton("Accept All & Close")

        # ---- Wire listeners ----
        nav = self

        class _TableListener(ListSelectionListener):
            def valueChanged(self, e):
                if e.getValueIsAdjusting():
                    return
                row = nav._table.getSelectedRow()
                if row >= 0:
                    nav._select_image(nav._image_names[row])

        class _SliderListener(ChangeListener):
            def stateChanged(self, e):
                if nav._updating_slider or nav._current_image is None:
                    return
                nav._preview_threshold(nav._slider.getValue() / nav._SLIDER_SCALE)

        # Three separate listener classes avoid the Jython Java-proxy issue that
        # arises when an interface implementor has a custom __init__ with args.
        class _RedListener(ActionListener):
            def actionPerformed(self, e):
                nav._channel_mode = "red"
                nav._apply_channel_mode()

        class _BlueListener(ActionListener):
            def actionPerformed(self, e):
                nav._channel_mode = "blue"
                nav._apply_channel_mode()

        class _MergeListener(ActionListener):
            def actionPerformed(self, e):
                nav._channel_mode = "merge"
                nav._apply_channel_mode()

        class _ResetListener(ActionListener):
            def actionPerformed(self, e):
                if nav._current_image is None:
                    return
                otsu = nav._otsu_thresholds.get(nav._current_image, nav._DEFAULT_THRESHOLD)
                nav._set_slider_quietly(otsu)
                # Preview with the exact seeded value rather than the slider's
                # 0.01-rounded copy of it.
                nav._preview_threshold(otsu)

        class _AcceptListener(ActionListener):
            def actionPerformed(self, e):
                if nav._current_image is not None:
                    nav._accept_image(nav._current_image,
                                      nav._slider.getValue() / nav._SLIDER_SCALE)

        class _AcceptAllListener(ActionListener):
            def actionPerformed(self, e):
                nav._accept_all_and_close()

        class _CloseListener(WindowAdapter):
            def windowClosing(self, e):
                nav._accept_all_and_close()

        self._slider_listener = _SliderListener()
        self._table.getSelectionModel().addListSelectionListener(_TableListener())
        self._slider.addChangeListener(self._slider_listener)
        self._btn_red.addActionListener(_RedListener())
        self._btn_blue.addActionListener(_BlueListener())
        self._btn_merge.addActionListener(_MergeListener())
        self._btn_reset.addActionListener(_ResetListener())
        self._btn_accept.addActionListener(_AcceptListener())
        self._btn_accept_all.addActionListener(_AcceptAllListener())

        # ---- Layout ----
        channel_panel = JPanel(FlowLayout(FlowLayout.LEFT, 4, 2))
        channel_panel.add(JLabel("Channel:"))
        channel_panel.add(self._btn_red)
        channel_panel.add(self._btn_blue)
        channel_panel.add(self._btn_merge)

        slider_panel = JPanel(FlowLayout(FlowLayout.LEFT, 4, 2))
        slider_panel.add(JLabel("Threshold (0-255):"))
        slider_panel.add(self._slider)

        count_panel = JPanel(FlowLayout(FlowLayout.LEFT, 4, 2))
        count_panel.add(self._count_label)
        count_panel.add(self._btn_reset)

        action_panel = JPanel(FlowLayout(FlowLayout.RIGHT, 4, 2))
        action_panel.add(self._btn_accept)
        action_panel.add(self._btn_accept_all)

        ctrl_panel = JPanel(GridBagLayout())
        ctrl_panel.setBorder(BorderFactory.createEmptyBorder(6, 4, 4, 4))
        gbc = GridBagConstraints()
        gbc.gridx   = 0
        gbc.fill    = GridBagConstraints.HORIZONTAL
        gbc.weightx = 1.0
        gbc.insets  = Insets(1, 0, 1, 0)
        for idx, panel in enumerate([channel_panel, slider_panel, count_panel, action_panel]):
            gbc.gridy = idx
            ctrl_panel.add(panel, gbc)

        main_panel = JPanel(BorderLayout())
        main_panel.setBorder(BorderFactory.createEmptyBorder(8, 8, 4, 8))
        main_panel.add(scroll, BorderLayout.CENTER)
        main_panel.add(ctrl_panel, BorderLayout.SOUTH)

        self._frame = JFrame("p63 Batch Navigator")
        # Closing is routed through _CloseListener so results are always saved.
        self._frame.setDefaultCloseOperation(WindowConstants.DO_NOTHING_ON_CLOSE)
        self._frame.addWindowListener(_CloseListener())
        self._frame.setContentPane(main_panel)
        self._frame.pack()
        self._frame.setLocationRelativeTo(None)

        self._cell_table = _CellTable(self._on_results_row_selected)

    # ------------------------------------------------------------------
    # Slider / preview helpers
    # ------------------------------------------------------------------

    def _set_slider_quietly(self, t):
        """Move the slider to threshold *t* without triggering the slider listener."""
        self._updating_slider = True
        try:
            self._slider.setValue(int(round(t * self._SLIDER_SCALE)))
        finally:
            self._updating_slider = False

    def _preview_threshold(self, t):
        """Recolour the open image's ROIs at *t* and refresh counts and cell table.

        Does nothing when no image is open; nothing is persisted.
        """
        if self._current_imp is None:
            return
        rois = self._roi_cache.get(self._current_image, [])
        n_pos, n_neg = review._recolor_rois(rois, t, None, self._current_imp)
        self._count_label.setText(self._format_counts(n_pos, n_neg, t))
        self._cell_table.update_p63_col(t)

    # ------------------------------------------------------------------
    # Image selection
    # ------------------------------------------------------------------

    def _select_image(self, name):
        """Open image *name*, overlay its mask ROIs and enable the threshold controls.

        On any failure a warning is shown in the count label, the cell table
        is cleared and the controls stay disabled.
        """
        # Disable controls during load
        self._slider.setEnabled(False)
        self._btn_reset.setEnabled(False)
        self._btn_accept.setEnabled(False)
        self._count_label.setText("Loading {} ...".format(name))

        # Close previous image window
        self._close_current_imp()

        self._current_image = name
        orig_path = self._image_map.get(name)
        mask_path = os.path.join(self._output_dir, name + "_cp_masks.tif")

        if not orig_path or not os.path.isfile(str(orig_path)):
            self._show_load_failure("[WARN] Image file not found: " + name)
            return
        if not os.path.isfile(mask_path):
            self._show_load_failure("[WARN] Mask not found for: " + name)
            return

        imp = self._open_composite(str(orig_path))
        if imp is None:
            self._show_load_failure("[WARN] Could not open image: " + name)
            return

        # Build ROI list (lazy; cached after first load)
        if name not in self._roi_cache:
            error = self._load_rois(name, mask_path)
            if error is not None:
                imp.close()
                self._show_load_failure(
                    "[WARN] ROI load failed for {}: {}".format(name, error))
                return

        self._current_imp = imp
        rois = self._roi_cache[name]
        t = self._current_thresholds.get(name, self._DEFAULT_THRESHOLD)

        self._set_slider_quietly(t)
        self._slider.setEnabled(True)

        n_pos, n_neg = review._recolor_rois(rois, t, None, imp)
        imp.show()
        self._apply_channel_mode()
        # The canvas exists only once the image is shown.
        self._attach_canvas_listener(imp)

        labels = self._label_cache.get(name, [])
        self._cell_table.update_all(rois, labels, t)
        self._cell_table.show()

        self._btn_reset.setEnabled(True)
        self._btn_accept.setEnabled(True)
        self._count_label.setText(self._format_counts(n_pos, n_neg, t))

    def _close_current_imp(self):
        """Close the open image window, if any, and forget it."""
        prev = self._current_imp
        # Clear first so listeners firing during close() see no open image.
        self._current_imp = None
        if prev is not None:
            prev.close()

    def _show_load_failure(self, message):
        """Report a failed image load and clear the now-stale cell table."""
        self._count_label.setText(message)
        self._cell_table.update_all([], [], self._DEFAULT_THRESHOLD)

    def _open_composite(self, path):
        """Open *path* as an ImagePlus suitable for per-channel display.

        Returns None if ImageJ cannot open the file.
        """
        imp = IJ.openImage(path)
        if imp is None:
            return None
        if imp.getBitDepth() == 24:
            # Convert 24-bit RGB to a 3-channel 8-bit composite so setMode/setC
            # work. ImageConverter.convertToRGBStack modifies imp in place
            # (R=ch1, G=ch2, B=ch3), matching quantChannel=1 (p63/red) and
            # segChannel=3 (DAPI).
            ImageConverter(imp).convertToRGBStack()
            imp.setDimensions(3, 1, 1)
            imp = CompositeImage(imp, CompositeImage.COMPOSITE)
        elif imp.getNChannels() > 1 and not isinstance(imp, CompositeImage):
            imp = CompositeImage(imp, CompositeImage.COMPOSITE)
        self._assign_channel_luts(imp)
        return imp

    def _assign_channel_luts(self, imp):
        """Give ch1/ch2/ch3 Red/Green/Blue LUTs so single-channel COLOR mode shows the right hue.

        Only channels that exist are touched: ImagePlus.setC clamps an
        out-of-range channel to the last one, which would otherwise recolour
        that channel. Single-channel images keep their own LUT.
        """
        n_ch = imp.getNChannels()
        if n_ch < 2:
            return
        try:
            for ch, lut in ((1, "Red"), (2, "Green"), (3, "Blue")):
                if ch > n_ch:
                    break
                imp.setC(ch)
                IJ.run(imp, lut, "")
            imp.setC(1)
        except Exception as exc:
            # Cosmetic only; the image is still usable with its original LUTs.
            IJ.log("[WARN] Could not assign channel LUTs: {}".format(str(exc)))

    def _load_rois(self, name, mask_path):
        """Populate _roi_cache/_label_cache for *name* from its label mask.

        Returns None on success, or an error message string on failure.
        """
        try:
            mask_imp = display.load_mask(mask_path)
            try:
                rois_with_labels = display.masks_to_rois(mask_imp)
            finally:
                # Release the mask even when ROI extraction fails.
                if mask_imp is not None:
                    mask_imp.close()
        except Exception as exc:
            return str(exc)
        mean_by_label = self._parse_mean_by_label(self._cells_by_image.get(name, []))
        self._label_cache[name] = [lbl for lbl, _roi in rois_with_labels]
        self._roi_cache[name] = [
            (roi, mean_by_label.get(lbl, 0.0))
            for lbl, roi in rois_with_labels
        ]
        return None

    def _attach_canvas_listener(self, imp):
        """Attach a mouse listener so clicking a cell highlights its table row."""
        nav = self

        class _CanvasListener(MouseListener):
            def mouseClicked(self, e):
                src = e.getSource()
                nav._on_canvas_click(
                    src.offScreenX(e.getX()),
                    src.offScreenY(e.getY()),
                )
            def mousePressed(self, e): pass
            def mouseReleased(self, e): pass
            def mouseEntered(self, e): pass
            def mouseExited(self, e): pass

        canvas = imp.getCanvas()
        if canvas is not None:
            canvas.addMouseListener(_CanvasListener())

    # ------------------------------------------------------------------
    # Channel display
    # ------------------------------------------------------------------

    def _apply_channel_mode(self):
        """Show the open image as a merged composite or as a single channel.

        "red" shows quant_channel and "blue" shows seg_channel, both clamped
        to the image's channel range. Single-channel images are left untouched.
        """
        imp = self._current_imp
        if imp is None:
            return
        n_ch = imp.getNChannels()
        if n_ch < 2:
            return
        if self._channel_mode == "merge":
            try:
                imp.setMode(CompositeImage.COMPOSITE)
            except Exception:
                # Best effort; presumably fails only when imp is not a CompositeImage.
                pass
        else:
            ch = self._quant_channel if self._channel_mode == "red" else self._seg_channel
            ch = min(max(ch, 1), n_ch)
            try:
                imp.setMode(CompositeImage.COLOR)
            except Exception:
                pass
            imp.setC(ch)
        imp.updateAndDraw()

    # ------------------------------------------------------------------
    # Bidirectional selection helpers
    # ------------------------------------------------------------------

    def _on_results_row_selected(self, row):
        """Called when the user clicks a row in the cell table; highlights that ROI."""
        if self._current_imp is None or self._current_image is None:
            return
        rois = self._roi_cache.get(self._current_image, [])
        if 0 <= row < len(rois):
            self._current_imp.setRoi(rois[row][0])

    def _on_canvas_click(self, img_x, img_y):
        """Called when the user clicks the image canvas; selects the matching table row."""
        name = self._current_image
        if name is None:
            return
        rois = self._roi_cache.get(name, [])
        for idx, (roi, _mean) in enumerate(rois):
            if roi.contains(img_x, img_y):
                self._cell_table.select_row(idx)
                break

    # ------------------------------------------------------------------
    # Accept / save
    # ------------------------------------------------------------------

    def _accept_image(self, name, final_t):
        """Write final calls for *name* at *final_t* and mark it Reviewed.

        If writing fails, the error is logged and shown, and the image stays
        pending so "Accept All & Close" will still handle it.
        """
        final_t = float(final_t)
        otsu = self._otsu_thresholds.get(name, self._DEFAULT_THRESHOLD)
        was_manual = abs(final_t - otsu) > self._MANUAL_TOLERANCE
        cells = self._cells_by_image.get(name, [])
        try:
            review._write_final_calls(cells, final_t, self._output_dir, was_manual, self._variant_str)
        except Exception as exc:
            IJ.log("[WARN] Could not save threshold for {}: {}".format(name, str(exc)))
            self._count_label.setText("[WARN] Save failed for: " + name)
            return
        self._current_thresholds[name] = final_t
        self._accepted_images.add(name)
        try:
            row = self._image_names.index(name)
        except ValueError:
            return
        self._table_model.setValueAt("Reviewed", row, self._COL_STATUS)
        self._table_model.setValueAt("{:.1f}".format(final_t), row, self._COL_THRESHOLD)

    def _accept_all_and_close(self):
        """Save every unaccepted image with its stored threshold, then close everything.

        The latch is always counted down, even if saving or cleanup fails, so
        show_navigator() never blocks forever.
        """
        # Collect all unreviewed images and write both CSVs in a single pass.
        pending_thresholds = {}
        pending_manual     = {}
        for name in self._image_names:
            if name in self._accepted_images:
                continue
            otsu = self._otsu_thresholds.get(name, self._DEFAULT_THRESHOLD)
            t    = self._current_thresholds.get(name, otsu)
            pending_thresholds[name] = t
            pending_manual[name]     = abs(t - otsu) > self._MANUAL_TOLERANCE
        try:
            if pending_thresholds:
                try:
                    review._write_final_calls_batch(
                        self._cells_by_image, pending_thresholds,
                        self._output_dir, pending_manual, self._variant_str,
                    )
                except Exception as exc:
                    IJ.log("[WARN] Could not save thresholds: {}".format(str(exc)))
            self._close_current_imp()
            if self._cell_table is not None:
                self._cell_table._frame.dispose()
            self._frame.dispose()
        finally:
            self._latch.countDown()


# ------------------------------------------------------------------
# Public entry point
# ------------------------------------------------------------------

def show_navigator(summary_rows, cells_by_image, image_map, output_dir,
                   variant_str, threshold_col, seg_channel, quant_channel,
                   results_table):
    """Create and show the batch navigator JFrame; block until user closes or accepts all.

    summary_rows   -- list of dicts from read_summary_csv
    cells_by_image -- {image_name: [cell_row_dicts]} pre-grouped from read_cells_csv
    image_map      -- {basename: absolute_path_string} for original images
    output_dir     -- string path to output directory
    seg_channel    -- 1-based int (blue/DAPI channel)
    quant_channel  -- 1-based int (red/p63 channel)
    results_table  -- ImageJ ResultsTable populated with all cells
    """
    if not summary_rows:
        return

    latch = CountDownLatch(1)
    nav_holder = [None]

    class _Build(Runnable):
        def run(self):
            nav = _Navigator(
                summary_rows, cells_by_image, image_map, output_dir,
                variant_str, threshold_col, seg_channel, quant_channel,
                results_table, latch,
            )
            nav_holder[0] = nav
            nav._frame.setVisible(True)

    SwingUtilities.invokeLater(_Build())
    latch.await()
