import { Typography } from "@mui/material";
import { useEffect } from "react";
import { toast } from "react-toastify";
import { GetGpuCostBreakdownResponse } from "../../../api/fetcher";
import { SCALEOPS_COLORS } from "../../../colors";
import AverageSpan from "../../../components/AverageSpan";
import PartialBorders from "../../../components/PartialBorders";
import RunningNumberBox from "../../Overview/TopOverviewSection/RunningNumberBox";
import useGetGpuMetrics from "./useGetGpuMetrics";
import useGetGpuNodesInfo from "./useGetGpuNodesInfo";
import useGetGpuClusterInfo from "./useGetGpuClusterInfo";

// Height class passed as `wrapperClassName` to each of the two boxes stacked in the last column.
const HALF_DIV_CLASS_WRAP = "h-[134px]";
// Sizing classes for the <div> wrappers placed around the first two boxes of the card.
const WRAPPER_CLASS_NAME = "w-full h-full";

/**
 * Shape holding aggregate GPU figures: a total cost, node and GPU counts, and GPU requests.
 *
 * NOTE(review): nothing in the visible part of this file references this type, and it shares
 * its name with the default-exported component below — confirm whether it is still needed.
 */
interface GpuMetricsSummary {
  totalGpuCost: number;
  numGpuNodes: number;
  numGpus: number;
  gpuRequests: number;
}

/**
 * Derives aggregate GPU compute and memory utilization from a cost-breakdown response.
 *
 * Data points flagged as `missing` are ignored. Compute utilization is the mean of
 * `smActive` over the remaining points; memory utilization is the summed `used` memory
 * divided by the summed allocatable memory, where a point's allocatable memory is its
 * `total`, or `used + free` when `total` is absent.
 *
 * @param data - The GPU cost-breakdown response; may be undefined (e.g. not loaded yet).
 * @returns Both utilizations as ratios. Each value is `NaN` when it cannot be computed:
 *   no response, no valid data points, or (for memory) zero allocatable memory.
 */
const gpuUtilizationFromData = (
  data?: GetGpuCostBreakdownResponse
): {
  computeUtilization: number;
  memoryUtilization: number;
} => {
  const validDataPoints = data?.gpuDataPoints?.filter((dataPoint) => !dataPoint.missing) ?? [];

  // An empty array is truthy, so emptiness has to be checked through its length.
  if (validDataPoints.length === 0) {
    return {
      computeUtilization: NaN,
      memoryUtilization: NaN,
    };
  }

  let totalSmActive = 0;
  let totalUsedMemory = 0;
  let totalAllocatableMemory = 0;

  for (const dataPoint of validDataPoints) {
    totalSmActive += dataPoint.smActive ?? 0;
    totalUsedMemory += dataPoint.used ?? 0;
    // Prefer the reported total; otherwise reconstruct it from used + free.
    totalAllocatableMemory += dataPoint.total ?? (dataPoint.used ?? 0) + (dataPoint.free ?? 0);
  }

  return {
    computeUtilization: totalSmActive / validDataPoints.length,
    // Report "no data" (NaN) rather than Infinity when no allocatable memory was reported.
    memoryUtilization: totalAllocatableMemory > 0 ? totalUsedMemory / totalAllocatableMemory : NaN,
  };
};

/**
 * Summary card showing GPU figures: total GPU nodes cost, the "GPU allocatable" average,
 * the average GPU request percentage, and average GPU compute / memory usage.
 *
 * Data comes from three queries (GPU metrics, GPU nodes info, GPU cluster info). If the
 * metrics or nodes-info query fails, an error toast is shown, the error is logged, and
 * nothing is rendered. A failure of the cluster-info query is only logged; in that case
 * the figures are still displayed. When the cluster reports no GPUs, every figure is
 * passed as `NaN`.
 */
export default function GpuMetricsSummary() {
  const {
    data: gpuMetricsData,
    isLoading: gpuMetricsIsLoading,
    isError: gpuMetricsIsError,
    error: gpuMetricsError,
  } = useGetGpuMetrics();

  const {
    data: gpuNodesInfoData,
    isLoading: gpuNodesInfoIsLoading,
    isError: gpuNodesInfoIsError,
    error: gpuNodesInfoError,
  } = useGetGpuNodesInfo();

  // Hooks must run in the same order on every render, so this query is issued before any
  // early return below.
  const {
    data: clusterGpuInfoData,
    isError: clusterGpuInfoIsError,
    error: clusterGpuInfoError,
  } = useGetGpuClusterInfo();

  // Error reporting is a side effect: run it from effects so it fires when the error state
  // changes instead of on every render. If both queries fail, both failures are reported.
  useEffect(() => {
    if (gpuMetricsIsError) {
      toast.error("Failed to get GPU metrics");
      console.error(gpuMetricsError);
    }
  }, [gpuMetricsIsError, gpuMetricsError]);

  useEffect(() => {
    if (gpuNodesInfoIsError) {
      toast.error("Failed to get GPU nodes info");
      console.error(gpuNodesInfoError);
    }
  }, [gpuNodesInfoIsError, gpuNodesInfoError]);

  useEffect(() => {
    if (clusterGpuInfoError) {
      console.error(`Error while fetching clusterGpuInfo: ${clusterGpuInfoError.message}`);
    }
  }, [clusterGpuInfoError]);

  if (gpuMetricsIsError || gpuNodesInfoIsError) {
    return null;
  }

  const gpuUtilization = gpuUtilizationFromData(gpuMetricsData);
  const avgNumGpuNodes = gpuNodesInfoData?.avgNumGPUNodes ?? 0;
  const avgNumGpuRequests = gpuNodesInfoData?.avgNumGPURequests ?? 0;
  const avgNumGpuCapacity = gpuNodesInfoData?.avgNumGPUCapacity ?? 1;
  const gpuNodesCost = gpuNodesInfoData?.gpuNodesCost ?? 0;

  // Requested share of GPU capacity as a ratio; NaN when the reported capacity is not positive,
  // so a zero capacity never produces Infinity.
  const gpuRequestRatio = avgNumGpuCapacity > 0 ? avgNumGpuRequests / avgNumGpuCapacity : NaN;

  const numGpusInCluster = clusterGpuInfoData?.numGpusInCluster;

  const activeGpusInCluster = numGpusInCluster
    ? numGpusInCluster > 0
    : // If we error, show the GPU cost report to be safe
      clusterGpuInfoIsError;

  return (
    <div className="border bg-white border-border rounded-lg p-5 flex items-center justify-center relative h-[18.75rem]">
      <div className={WRAPPER_CLASS_NAME}>
        <PartialBorders>
          <RunningNumberBox
            title={<>Total GPU nodes cost</>}
            value={activeGpusInCluster ? gpuNodesCost : NaN}
            prefix="$"
            numberVariant="h4"
            numberClassName="text-text-lightBlack"
            isLoading={gpuNodesInfoIsLoading}
          />
        </PartialBorders>
      </div>
      <div className={WRAPPER_CLASS_NAME}>
        <PartialBorders left>
          <RunningNumberBox
            title={
              <>
                GPU allocatable <AverageSpan />
              </>
            }
            value={activeGpusInCluster ? avgNumGpuNodes : NaN}
            numberVariant="h4"
            numberClassName="text-text-lightBlack"
            isLoading={gpuNodesInfoIsLoading}
            showRoundedValue={false}
          />
        </PartialBorders>
      </div>
      <PartialBorders left right>
        <RunningNumberBox
          title={
            <>
              GPU request <AverageSpan />
            </>
          }
          value={activeGpusInCluster ? gpuRequestRatio * 100 : NaN}
          suffix="%"
          barPercentageValue={gpuRequestRatio}
          barPercentageColor={SCALEOPS_COLORS.guideline.darkYellow}
          numberVariant="h4"
          numberClassName="text-text-lightBlack"
          isLoading={gpuNodesInfoIsLoading}
          barTooltipTitleFormatter={() => (
            <Typography variant="subtitle1">
              Requested an average of <strong>{avgNumGpuRequests.toFixed(2)}</strong> out of{" "}
              <strong>{avgNumGpuCapacity.toFixed(2)}</strong> GPUs in the cluster.
            </Typography>
          )}
        />
      </PartialBorders>
      <PartialBorders>
        <PartialBorders bottom wrapperClassName={HALF_DIV_CLASS_WRAP}>
          <RunningNumberBox
            title={
              <>
                GPU compute usage <AverageSpan />
              </>
            }
            value={activeGpusInCluster ? gpuUtilization.computeUtilization * 100 : NaN}
            numberVariant="h4"
            numberClassName="text-text-lightBlack"
            isLoading={gpuMetricsIsLoading}
            suffix="%"
            barPercentageValue={gpuUtilization.computeUtilization}
            barPercentageColor={SCALEOPS_COLORS.main.blue}
          />
        </PartialBorders>
        <PartialBorders wrapperClassName={HALF_DIV_CLASS_WRAP}>
          <RunningNumberBox
            title={
              <>
                GPU memory usage <AverageSpan />
              </>
            }
            value={activeGpusInCluster ? gpuUtilization.memoryUtilization * 100 : NaN}
            numberVariant="h4"
            numberClassName="text-text-lightBlack"
            isLoading={gpuMetricsIsLoading}
            suffix="%"
            barPercentageValue={gpuUtilization.memoryUtilization}
            barPercentageColor={SCALEOPS_COLORS.main.blue}
          />
        </PartialBorders>
      </PartialBorders>
    </div>
  );
}
