import React, { useCallback, useContext, useMemo, useRef, useState } from "react";

/**
 * API that `SaveChangesProvider` exposes to its descendants.
 */
export interface SaveChangesContext {
  /**
   * Runs every registered callback one after another, in registration order.
   * If one rejects, the rest are skipped and the returned promise rejects.
   */
  invokeCallbacks: () => Promise<void>;
  /** Registers `callback` under `id`, replacing any callback already held for that id. */
  registerCallback: (callback: CallbackFunction, id: CallbackId) => void;
}

/**
 * Context carrying the save-changes API. It defaults to `undefined` so that
 * `useSaveChangesContext` can detect use outside of a provider.
 */
export const SaveChangesContext = React.createContext<SaveChangesContext | undefined>(undefined);

/** Props for `SaveChangesProvider`. */
interface SaveChangesProviderProps {
  /** Subtree that can reach the save-changes API through `useSaveChangesContext`. */
  children: React.ReactNode;
}

/** Async routine contributed by a step; `invokeCallbacks` awaits it before starting the next one. */
type CallbackFunction = () => Promise<void>;

/** Step that owns a callback; the provider keeps at most one callback per id. */
type CallbackId =
  | "step1"
  | "step2"
  | "step3";

/** A registered callback paired with the step it belongs to. */
interface CallbackObject {
  id: CallbackId;
  callback: CallbackFunction;
}

/**
 * Provides a registry of per-step async save callbacks to its descendants.
 *
 * Descendants register a callback under a step id with `registerCallback`.
 * Registering again with the same id replaces the previous callback and moves
 * it to the end of the run order. `invokeCallbacks` awaits each registered
 * callback in turn, in registration order. If one rejects, the remaining
 * callbacks are skipped and the returned promise rejects with that error.
 *
 * The registry lives in a ref rather than in state because nothing rendered
 * depends on it. Registering therefore does not re-render the provider, nor
 * every consumer through a fresh context value. `invokeCallbacks` always reads
 * the latest registrations, even when the caller captured it in an earlier render.
 *
 * NOTE(review): there is no way to unregister; a step's callback stays in the
 * registry after that step unmounts.
 */
export function SaveChangesProvider(props: SaveChangesProviderProps) {
  const callbacksRef = useRef<CallbackObject[]>([]);

  // Empty deps give a stable identity, so consumers can list this in effect
  // deps without triggering a register -> re-render -> register loop.
  const registerCallback = useCallback((callback: CallbackFunction, id: CallbackId) => {
    // This stops us registering multiple callbacks per step
    const others = callbacksRef.current.filter(cb => cb.id !== id);
    callbacksRef.current = [...others, { callback, id }];
  }, []);

  const invokeCallbacks = useCallback(async (): Promise<void> => {
    // Take a snapshot so registrations made while saving don't alter this run.
    const snapshot = callbacksRef.current;
    for (const { callback } of snapshot) {
      // Sequential on purpose: each step's save finishes before the next starts.
      await callback();
    }
  }, []);

  // Memoized so the context value keeps its identity across provider re-renders.
  const value = useMemo<SaveChangesContext>(
    () => ({ registerCallback, invokeCallbacks }),
    [registerCallback, invokeCallbacks]
  );

  return (
    <SaveChangesContext.Provider value={value}>
      {props.children}
    </SaveChangesContext.Provider>
  );
}

/**
 * Returns the save-changes API from the nearest `SaveChangesProvider`.
 *
 * @throws Error if called from a component that is not rendered inside a
 *   `SaveChangesProvider` (the context still holds its `undefined` default).
 */
export function useSaveChangesContext() {
  const ctx = useContext(SaveChangesContext);
  if (ctx !== undefined) {
    return ctx;
  }
  throw new Error("useSaveChangesContext must be used within a SaveChangesProvider");
}
