import { Stack, useTheme } from '@mui/material';
import { AxisBottom, AxisLeft, AxisRight, AxisTop } from '@visx/axis';
import { Grid } from '@visx/grid';
import { Group } from '@visx/group';
import { scaleBand, scaleLinear, scaleOrdinal, StringLike } from '@visx/scale';
import { Bar } from '@visx/shape';
import { defaultStyles, Tooltip, useTooltip } from '@visx/tooltip';
import { useCallback, useMemo } from 'react';

import { Colors } from './constants';
import { Legend } from './Legend';
import { BarChartVerticalProps, TooltipData } from './types';
import { getChartContainerDirection, getChartDimensions } from './utils';

const defaultMargin = { top: 40, left: 50, right: 40, bottom: 100 };

// Hoisted for the same reason as `defaultMargin`: an inline `[Colors.BLUE]`
// default would be a new array on every render and invalidate the
// `colorScale` memo each time.
const defaultBarColors = [Colors.BLUE];

/**
 * Vertical bar chart rendered with visx inside an MUI `Stack`.
 *
 * Each item of `data` becomes one bar: `xKey` selects the category placed on
 * the band x axis, `yKey` selects the bar value. Values that do not convert to
 * a number are drawn as 0. The y domain always contains 0; positive values
 * grow up from the zero line and negative values grow down from it.
 *
 * Optional pieces (grid, each axis, value labels, legend, hover tooltip) are
 * toggled by their respective props. `onClickHandler`, when given, is called
 * with the clicked data item.
 *
 * @returns `null` when `width` is below 10px, otherwise the chart element.
 */
export function BarChart<T extends Record<string, string | number>>({
  data,
  xKey,
  yKey,
  width,
  height,
  showLabels = true,
  barPadding = 0.4,
  barColors = defaultBarColors,
  labelSize = 12,
  labelColor = Colors.BLUE,
  tooltips = true,
  tooltipTextColor = Colors.WHITE,
  tooltipBackgroundColor = Colors.BLACK,
  axisTop = false,
  axisLeft = true,
  axisBottom = true,
  axisRight = false,
  axisColor = Colors.BLUE,
  showGrid = true,
  gridColor = Colors.BLACK,
  gridOpacity = 0.07,
  margin = defaultMargin,
  showLegends = false,
  legendPosition = 'left',
  legendDirection = 'column',
  legendJustify = 'center',
  legendAlign = 'center',
  legendIconShape = 'circle',
  legendFontSize = 12,
  legendContainerSize = 0.2,
  onClickHandler,
}: BarChartVerticalProps<T>) {
  const {
    tooltipData,
    tooltipLeft,
    tooltipTop,
    tooltipOpen,
    showTooltip,
    hideTooltip,
  } = useTooltip<TooltipData>();

  const [
    svgWidth,
    svgHeight,
    chartWidth,
    chartHeight,
    legendContainerWidth,
    legendContainerHeight,
  ] = getChartDimensions(
    width,
    height,
    margin,
    showLegends,
    legendPosition,
    legendContainerSize
  );

  const theme = useTheme();

  const tooltipStyles = {
    ...defaultStyles,
    minWidth: 60,
    backgroundColor: tooltipBackgroundColor,
    color: tooltipTextColor,
    zIndex: theme.zIndex.tooltip,
  };

  const getX = useCallback((d: T): string => String(d[xKey]), [xKey]);

  // Non-numeric values are treated as 0 rather than producing NaN geometry.
  const getY = useCallback(
    (d: T): number => (!isNaN(Number(d[yKey])) ? Number(d[yKey]) : 0),
    [yKey]
  );

  const getId = useCallback((d: T): StringLike => d[yKey], [yKey]);

  // scales, memoize for performance
  const xScale = useMemo(
    () =>
      scaleBand<string>({
        range: [0, chartWidth],
        round: true,
        domain: data.map(getX),
        padding: barPadding,
      }),
    [chartWidth, data, getX, barPadding]
  );

  const yScale = useMemo(() => {
    // Single pass instead of Math.max(...spread): the spread returns -Infinity
    // for empty data and can overflow the call stack on very large arrays.
    let yMin = 0;
    let yMax = 0;
    for (const d of data) {
      const value = getY(d);
      if (value < yMin) yMin = value;
      if (value > yMax) yMax = value;
    }
    // Starting both bounds at 0 keeps the zero baseline in the domain. A
    // [0, 0] domain would make scaleLinear map every value to the middle of
    // the range, so an all-zero (or empty) data set gets [0, 1] instead.
    const domainTop = yMin === 0 && yMax === 0 ? 1 : yMax;
    return scaleLinear<number>({
      range: [chartHeight, 0],
      round: true,
      domain: [yMin, domainTop],
    });
  }, [chartHeight, data, getY]);

  const colorScale = useMemo(
    () =>
      scaleOrdinal<string | number, string>({
        domain: data.map(getX),
        range: barColors,
      }),
    [data, getX, barColors]
  );

  // Pixel row of the value 0; bars are drawn between this line and their value.
  const baselineY = yScale(0);

  return width < 10 ? null : (
    <Stack
      direction={getChartContainerDirection(legendPosition)}
      height="100%"
      maxWidth="fit-content"
    >
      <svg width={svgWidth} height={svgHeight}>
        {showGrid && (
          <Grid
            top={margin.top}
            left={margin.left}
            xScale={xScale}
            yScale={yScale}
            width={chartWidth}
            height={chartHeight}
            stroke={gridColor}
            strokeOpacity={gridOpacity}
          />
        )}
        <Group top={margin.top} left={margin.left}>
          {data.map((d, index) => {
            const id = getId(d);
            const xValue = getX(d);
            const yValue = getY(d);
            const barWidth = xScale.bandwidth();
            // scaleBand yields undefined for values outside its domain.
            const barX = xScale(xValue) ?? 0;
            const valueY = yScale(yValue);
            // Works for both signs: the rect always spans from the higher of
            // the two rows (smaller y) down to the lower one.
            const barY = Math.min(valueY, baselineY);
            const barHeight = Math.abs(baselineY - valueY);
            // Labels sit above positive bars and below negative ones.
            const labelY =
              yValue < 0 ? barY + barHeight + labelSize : barY - labelSize / 2;

            return (
              <Group key={`bar-${id}-${index}`}>
                <Bar
                  x={barX}
                  y={barY}
                  width={barWidth}
                  height={barHeight}
                  fill={colorScale(xValue)}
                  onMouseLeave={
                    tooltips
                      ? () => {
                          hideTooltip();
                        }
                      : undefined
                  }
                  onMouseMove={
                    tooltips
                      ? () => {
                          showTooltip({
                            tooltipData: { key: xValue, value: yValue },
                            // Bar coordinates are local to the margin-offset
                            // <Group>; translate back to svg coordinates.
                            tooltipTop: margin.top + barY,
                            tooltipLeft: margin.left + barX,
                          });
                        }
                      : undefined
                  }
                  style={{ cursor: 'pointer' }}
                  onClick={onClickHandler ? () => onClickHandler(d) : undefined}
                />
                {showLabels && (
                  <text
                    fontSize={labelSize}
                    textAnchor="middle"
                    fill={labelColor}
                    x={barX + barWidth / 2}
                    y={labelY}
                  >
                    {yValue}
                  </text>
                )}
              </Group>
            );
          })}
        </Group>
        {axisTop && (
          <AxisTop
            left={margin.left}
            top={margin.top}
            scale={xScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'middle',
              dy: '-0.33em',
            })}
          />
        )}
        {axisBottom && (
          <AxisBottom
            left={margin.left}
            top={chartHeight + margin.top}
            scale={xScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'middle',
            })}
          />
        )}
        {axisLeft && (
          <AxisLeft
            left={margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'end',
              dy: '0.33em',
            })}
          />
        )}
        {axisRight && (
          <AxisRight
            left={chartWidth + margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'start',
              dy: '0.33em',
            })}
          />
        )}
      </svg>
      {showLegends && (
        <Legend
          legendPosition={legendPosition}
          legendDirection={legendDirection}
          legendJustify={legendJustify}
          legendAlign={legendAlign}
          legendFontSize={legendFontSize}
          legendIconShape={legendIconShape}
          maxWidth={legendContainerWidth}
          maxHeight={legendContainerHeight}
          margin={margin}
          colorScale={colorScale}
        />
      )}
      {tooltips && tooltipOpen && tooltipData && (
        <Tooltip top={tooltipTop} left={tooltipLeft} style={tooltipStyles}>
          <div>
            <strong>{tooltipData.key}</strong>
          </div>
          <div>{tooltipData.value}</div>
        </Tooltip>
      )}
    </Stack>
  );
}
