import {
  Paper,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
  tableCellClasses,
  tableRowClasses,
} from '@mui/material';
import { AppColors } from '@omnivivo/style';
import { groupBy, pick } from 'lodash';
import { type ReactNode } from 'react';

/** One table row: renderable cell values keyed by column `field`. */
type RowData = { [field: string]: ReactNode };

/**
 * Describes one column of a {@link DenseTable}.
 *
 * @typeParam TRow - Shape of the rows rendered by the table.
 */
interface ColumnDef<TRow extends RowData> {
  /** Row key whose value is displayed in this column; also used as the React key of its cells. */
  field: string;
  /** Text shown in the column's header cell. */
  headerName: string;
  /** Horizontal alignment of the header cell. */
  headerAlign?: 'left' | 'right' | 'center';
  /** Horizontal alignment of the body cells (the table falls back to `'left'`). */
  align?: 'left' | 'right' | 'center';
  /** Custom renderer for a body cell; when omitted, the cell shows `row[field]`. */
  renderCell?: (row: TRow) => ReactNode;
}

/**
 * Props accepted by {@link DenseTable}.
 *
 * @typeParam TColumn - Column definition type (allows callers to extend {@link ColumnDef}).
 * @typeParam TRow - Shape of each row in `rows`.
 */
interface DenseTableProps<TColumn extends ColumnDef<TRow>, TRow extends RowData> {
  /** Accessible name of the table, applied as its `aria-label`. */
  label: string;
  /** Column definitions, rendered left to right. */
  columns: TColumn[];
  /** Rows to display. */
  rows: TRow[];
  /**
   * Row keys whose shared values merge rows into one group; the matching columns are rendered
   * once per group with a `rowSpan` covering the group. Defaults to `[]`.
   */
  rowGrouping?: Array<keyof TRow>;
  /** When `true`, rows of every second group are shaded. Defaults to `false`. */
  striped?: boolean;
}

/**
 * Base `sx` styles applied to the `<Table>` rendered by {@link DenseTable}.
 */
const defaultStyles = {
  // Top-align every cell so multi-line content in a row lines up at its first line.
  [`.${tableCellClasses.root}`]: {
    verticalAlign: 'top',
  },
  [`.${tableCellClasses.head}`]: {
    fontWeight: 'bold',
    // Keep header labels on a single line. `white-space: nowrap` is long-standing and widely
    // supported, whereas the `text-wrap` shorthand only gained broad browser support recently.
    whiteSpace: 'nowrap',
    borderColor: AppColors.GREY1,
  },
  // Body cells tagged `grouped` (rows that are not the last of their row group) drop their
  // bottom border so a group reads as one visual block.
  [`.${tableCellClasses.body}.grouped`]: {
    borderBottom: 'none',
  },
  // Striping: shade the body cells of rows tagged `odd`.
  [`.${tableRowClasses.root}.odd .${tableCellClasses.body}`]: {
    backgroundColor: AppColors.GREY6,
  },
};

/**
 * Splits `rows` into the row groups rendered by {@link DenseTable}.
 *
 * Rows whose values in the `rowGrouping` columns are equal (compared by string form, with
 * `null`/`undefined` treated as `''`) share a group, even when they are not adjacent in `rows`.
 * Groups are returned in the order their first row appears, and each group keeps its rows in
 * input order. With no grouping columns, every row is its own group.
 *
 * NOTE(review): grouping columns are expected to hold primitive values; React elements all
 * stringify to "[object Object]" and would be merged together.
 */
function groupRows<TRow extends RowData>(rows: TRow[], rowGrouping: Array<keyof TRow>): TRow[][] {
  if (rowGrouping.length === 0) {
    return rows.map((row) => [row]);
  }
  // JSON-encoding the value list avoids collisions such as ['a|b', 'c'] vs ['a', 'b|c'] that a
  // plain join would produce. Because the key always starts with '[', it is never an
  // integer-like property name, so Object.values() yields groups in insertion (first-seen) order.
  const groupKey = (row: TRow): string =>
    JSON.stringify(Object.values(pick(row, rowGrouping)).map((value) => (value == null ? '' : String(value))));
  return Object.values(groupBy(rows, groupKey));
}

/**
 * Renders `rows` as a small-size MUI table inside a `Paper` container.
 *
 * Rows sharing the same values in every `rowGrouping` column form a group. The grouping
 * columns are rendered once, in the group's first row, with a `rowSpan` covering the whole
 * group. The other cells of every row except the group's last get the `grouped` class, which
 * removes the border between them. With `striped`, rows of every second group get the `odd`
 * class. The first cell of a group's first row is rendered as a row header (`<th scope="row">`).
 *
 * @param props - See {@link DenseTableProps}.
 */
export default function DenseTable<TColumn extends ColumnDef<TRow>, TRow extends RowData>(
  props: DenseTableProps<TColumn, TRow>
) {
  const { columns, label, rows, rowGrouping = [], striped = false } = props;

  const groupedRows = groupRows(rows, rowGrouping);

  return (
    <TableContainer component={Paper}>
      <Table sx={defaultStyles} size="small" aria-label={label}>
        <TableHead>
          <TableRow>
            {columns.map(({ field, headerAlign, headerName }) => (
              <TableCell key={field} align={headerAlign}>
                {headerName}
              </TableCell>
            ))}
          </TableRow>
        </TableHead>
        <TableBody>
          {groupedRows.flatMap((rowGroup, groupIndex) => {
            const rowClassName = striped && groupIndex % 2 === 1 ? 'odd' : undefined;
            return rowGroup.map((row, rowIndex) => {
              const isFirstInGroup = rowIndex === 0;
              const isLastInGroup = rowIndex === rowGroup.length - 1;
              // Grouped columns only appear in a group's first row, where they span the group.
              const rowColumns = isFirstInGroup
                ? columns
                : columns.filter((column) => !rowGrouping.includes(column.field));
              return (
                <TableRow key={`${groupIndex}-${rowIndex}`} className={rowClassName}>
                  {rowColumns.map((column, columnIndex) => {
                    const { align = 'left', field, renderCell = () => row[field] } = column;
                    const isGrouped = rowGrouping.includes(field);
                    const isRowHeader = isFirstInGroup && columnIndex === 0;
                    return (
                      <TableCell
                        key={field}
                        align={align}
                        component={isRowHeader ? 'th' : 'td'}
                        // `scope` is only meaningful on header cells.
                        scope={isRowHeader ? 'row' : undefined}
                        rowSpan={isGrouped ? rowGroup.length : undefined}
                        className={!isGrouped && !isLastInGroup ? 'grouped' : undefined}
                      >
                        {renderCell(row)}
                      </TableCell>
                    );
                  })}
                </TableRow>
              );
            });
          })}
        </TableBody>
      </Table>
    </TableContainer>
  );
}
