# Source code for negmas.checkpoints

"""Implements Checkpoint functionality for easy dumping and restoration of any `NamedComponent` in negmas."""

import shutil
from pathlib import Path
from typing import Optional, Union, Dict, Any, Callable, Type, List

import numpy as np

from negmas import NamedObject
from negmas.helpers import load


class CheckpointMixin:
    """Adds the ability to save checkpoints to a `NamedObject`.

    The class using this mixin must also provide a ``checkpoint`` method (as `NamedObject`
    does) accepting the keyword arguments used below. Typical usage:

    - call `checkpoint_init` once to configure checkpointing,
    - call `checkpoint_on_step_started` at the beginning of every step, and
    - call `checkpoint_final_step` once after the final step is processed.
    """

    def checkpoint_init(
        self,
        step_attrib: Optional[str] = "current_step",
        every: int = 1,
        folder: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
        info: Optional[Dict[str, Any]] = None,
        exist_ok: bool = True,
        single: bool = True,
    ):
        """
        Initializes the object to automatically save a checkpoint

        Args:
            step_attrib: The attribute that defines the current step. If None, there is no step
                         concept and the number of calls to `checkpoint_on_step_started` is used
                         as the step instead.
            every: Number of steps per checkpoint. If < 1 no checkpoints will be saved
            folder: The directory to store checkpoints under. If None, no checkpoints will be saved
            filename: Name of the file to save the checkpoint under. If None, a unique name will be
                      chosen. If `single` is False, then multiple files will be used prefixed with
                      the step number
            info: Any extra information to save in the json file associated with each checkpoint
            exist_ok: Override existing files if any
            single: If True, only the most recent checkpoint will be kept

        Remarks:

            - `single` implies `exist_ok` for checkpoints saved by `checkpoint_on_step_started`.
            - `checkpoint_final_step` always overrides existing files.

        """
        # A missing folder disables checkpointing altogether.
        self.__checkpoint_every = -1 if folder is None else every
        self.__checkpoint_folder = folder
        self.__checkpoint_extra_info = info
        self.__checkpoint_exist_ok = exist_ok
        self.__checkpoint_single = single
        # NOTE: the (misspelled) attribute name is kept as-is; renaming it would change the
        # name-mangled instance attribute that existing objects/code may rely on.
        self.__step_atrrib = step_attrib
        self.__checkpoint_filename = filename
        # Number of calls to checkpoint_on_step_started (used as the step if step_attrib is None).
        self.__checkpoint_calls = 0

    def __save_checkpoint(self, exist_ok: bool) -> Optional[Path]:
        """Saves a checkpoint now using the settings passed to `checkpoint_init`."""
        # Without a step attribute there is nothing for `checkpoint` to read the step from.
        step_attribs = () if self.__step_atrrib is None else (self.__step_atrrib,)
        return self.checkpoint(  # type: ignore[attr-defined]
            path=self.__checkpoint_folder,
            file_name=self.__checkpoint_filename,
            info=self.__checkpoint_extra_info,
            exist_ok=exist_ok,
            single_checkpoint=self.__checkpoint_single,
            step_attribs=step_attribs,
        )

    def checkpoint_on_step_started(self) -> Optional[Path]:
        """Should be called on every step to save checkpoints as needed.

        Returns:
            The path on which the checkpoint is stored if one is stored. None otherwise.

        Remarks:

            - Should be called at the BEGINNING of every step before any processing takes place
        """
        if self.__checkpoint_every < 1 or self.__checkpoint_folder is None:
            return None
        if self.__step_atrrib is None:
            # No step concept: count the calls to this method instead. getattr keeps this
            # working for objects initialized before the counter existed.
            step = getattr(self, "_CheckpointMixin__checkpoint_calls", 0)
            self.__checkpoint_calls = step + 1
        else:
            step = getattr(self, self.__step_atrrib)
        # every == 1 means "every call", regardless of the step value.
        if self.__checkpoint_every != 1 and step % self.__checkpoint_every != 0:
            return None
        return self.__save_checkpoint(
            exist_ok=self.__checkpoint_exist_ok or self.__checkpoint_single
        )

    def checkpoint_final_step(self) -> Optional[Path]:
        """Should be called at the end of the simulation to save the final state

        Returns:
            The path on which the checkpoint is stored, or None if checkpointing is disabled.

        Remarks:
            - Should be called after all processing of the final step is conducted.
            - Always overrides an existing checkpoint file.
        """
        if self.__checkpoint_every < 1 or self.__checkpoint_folder is None:
            return None
        return self.__save_checkpoint(exist_ok=True)


class CheckpointRunner:
    """Runs an object based on its checkpoints saved in an earlier run.

    On construction, the runner loads every info file in ``folder`` matching ``*.json`` (or
    ``*<id>*.json`` when ``id`` is given). It reads the ``"step"`` and ``"filename"`` entries
    of each. It then lets you move between the stored steps, loading the object saved at
    each one through ``object_type.from_checkpoint``.

    Args:
        folder: The folder containing the checkpoints.
        id: If given, only info files whose name contains this string are used.
        callback: Optional callable registered with `register_callback`.
        watch: Must be False; watching the folder for new checkpoints is not implemented.
        object_type: The type whose ``from_checkpoint`` is used to load checkpoints.

    Raises:
        NotImplementedError: If ``watch`` is True.
    """

    def __init__(
        self,
        folder: Union[str, Path],
        id: Optional[str] = None,
        callback: Optional[Callable[[NamedObject, int], None]] = None,
        watch: bool = False,
        object_type: Type[NamedObject] = NamedObject,
    ):
        if watch:
            # Fail fast, before touching the file system.
            raise NotImplementedError("File watching is not implemented yet")
        self.__folder = Path(folder).absolute()
        # Use the value of `id` in the pattern (not the literal text "id").
        pattern = "*.json" if id is None else f"*{id}*.json"
        self.__infos = [load(_) for _ in self.__folder.glob(pattern)]
        # step -> checkpoint file (a later info for the same step overrides an earlier one)
        self.__files = {info["step"]: info["filename"] for info in self.__infos}
        self.__sorted_steps = sorted(self.__files.keys())
        # Index into __sorted_steps of the loaded step; -1 means "before the first step".
        self._step_index = -1
        self.__object: Optional[NamedObject] = None
        self.__object_type = object_type
        self.__callbacks: List[Callable[[NamedObject, int], None]] = []
        if callback is not None:
            self.register_callback(callback)
        self.__watch = watch

    @property
    def current_step(self) -> int:
        """Gets the current step number (-1 if no step is loaded yet)"""
        if self._step_index < 0:
            return -1
        return self.__sorted_steps[self._step_index]

    def step(self) -> Optional[int]:
        """Go one step forward in the stored steps.

        Returns:
            The number of the current step or None if we are already on the last step.

        """
        nxt = self.next_step
        if nxt is None:
            return None
        return self.goto(nxt, exact=True)

    def run(self):
        """Run all remaining steps.

        Any callbacks registered with `register_callback` are called for every stored step
        loaded during the run."""
        while self.step() is not None:
            pass

    @property
    def loaded_object(self) -> Optional[NamedObject]:
        """The object stored in the current checkpoint (None if no step is loaded)"""
        return self.__object

    def fork(
        self,
        copy_past_checkpoints: bool = False,
        every: int = 1,
        folder: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
        info: Optional[Dict[str, Any]] = None,
        exist_ok: bool = True,
        single: bool = True,
    ) -> Optional[NamedObject]:
        """
        Creates a copy of the internal object that can be run safely.

        The copy is freshly loaded from the current step's checkpoint. Running it does not
        affect `loaded_object`.

        Args:
            copy_past_checkpoints: If true, all checkpoints up to and including current_step
                                   will be copied to the given folder. The folder is created if needed.
            every: Number of steps per checkpoint. If < 1 no checkpoints will be saved
            folder: The directory to store checkpoints under
            filename: Name of the file to save the checkpoint under. If None, a unique name will
                      be chosen. If `single` is False, then multiple files will be used prefixed
                      with the step number
            info: Any extra information to save in the json file associated with each checkpoint
            exist_ok: Override existing files if any
            single: If True, only the most recent checkpoint will be kept

        Returns:
            The new copy, or None if no step is currently loaded. If the object implements
            `CheckpointMixin`, its checkpointing is re-initialized with the given settings. Its
            previously configured step attribute is kept.

        Raises:
            ValueError: If new checkpoints are requested (``folder`` given and ``every > 0``)
                        for an object that does not implement `CheckpointMixin`. Also raised if
                        ``copy_past_checkpoints`` is True but no ``folder`` is given.
        """
        if self.__object is None:
            return None
        if (
            not isinstance(self.__object, CheckpointMixin)
            and folder is not None
            and every > 0
        ):
            raise ValueError(
                f"Object of type {self.__object.__class__.__name__} is not implementing the "
                f"CheckpointMixin. It cannot be forked"
            )
        if copy_past_checkpoints and folder is None:
            raise ValueError(
                "Cannot copy past checkpoints because no folder for new checkpoints is given"
            )

        if folder is not None:
            folder = Path(folder).absolute()

        if copy_past_checkpoints:
            # shutil.copy does not create missing destination directories.
            folder.mkdir(parents=True, exist_ok=True)
            current = self.current_step
            for f in (v for k, v in self.__files.items() if k <= current):
                # Each checkpoint is copied together with its "<file>.json" companion.
                shutil.copy(str(f), str(folder / Path(f).name))
                shutil.copy(str(f) + ".json", str(folder / (Path(f).name + ".json")))
        # Reload from disk so the fork is a separate instance from `loaded_object`.
        x = self.__object_type.from_checkpoint(
            self.__files[self.current_step], return_info=False
        )
        if isinstance(x, CheckpointMixin):
            # Keep the step attribute the object was configured with (if any) instead of
            # silently resetting it to the default.
            step_attrib = getattr(x, "_CheckpointMixin__step_atrrib", None) or "current_step"
            CheckpointMixin.checkpoint_init(
                x,
                step_attrib=step_attrib,
                every=every,
                folder=folder,
                filename=filename,
                info=info,
                exist_ok=exist_ok,
                single=single,
            )
        return x

    @property
    def steps(self) -> List[int]:
        """A sorted list of all stored steps"""
        return self.__sorted_steps

    def register_callback(self, callback: Callable[[NamedObject, int], None]) -> None:
        """Registers a callback to be called whenever a new step is loaded

        Args:
            callback: A callable that takes the named object (after it is loaded) and an integer
                      specifying the step number, and returns None.

        """
        self.__callbacks.append(callback)

    def goto(self, step: int, exact: bool = False) -> Optional[int]:
        """Goes to the given step (or the nearest stored one) and loads its checkpoint.

        Args:

            step: The step we want to goto. If None, nothing happens.
            exact: If True, ``step`` must be one of the stored steps. Otherwise, the first
                   stored step that is >= ``step`` is used. If ``step`` is beyond all stored
                   steps, the last stored step is used.

        Returns:

            - None if there are no stored steps, ``step`` is None, or the target step is already
              the current one. Otherwise the exact step we moved to

        Raises:

            ValueError: If ``exact`` is True and no checkpoint is stored for ``step``.

        """
        steps = self.__sorted_steps
        if step is None or not steps:
            return None
        # Index of the first stored step that is >= `step` (len(steps) if there is none).
        step_index = int(np.searchsorted(steps, step, side="left"))
        if exact:
            if step_index >= len(steps) or steps[step_index] != step:
                raise ValueError(f"No checkpoint is stored for step {step}")
        else:
            # Clamp past-the-end requests to the last stored step.
            step_index = min(step_index, len(steps) - 1)
            step = steps[step_index]

        if step_index == self._step_index:
            return None
        self.__object = self.__object_type.from_checkpoint(
            self.__files[step], return_info=False
        )
        self._step_index = step_index
        for callback in self.__callbacks:
            callback(self.__object, step)
        return step

    def reset(self) -> None:
        """Goes before the first step"""
        self._step_index = -1
        self.__object = None

    @property
    def next_step(self) -> Optional[int]:
        """Get the next stored step number (None if it does not exist)"""
        nxt = self._step_index + 1
        if len(self.__sorted_steps) > nxt:
            return self.__sorted_steps[nxt]
        return None

    @property
    def previous_step(self) -> Optional[int]:
        """Get the previous stored step number.

        Returns -1 if no step is loaded yet (mirroring `current_step`) and None if the current
        step is the first stored one."""
        if self._step_index < 0:
            return -1
        nxt = self._step_index - 1
        if 0 <= nxt:
            return self.__sorted_steps[nxt]
        return None

    @property
    def last_step(self) -> Optional[int]:
        """Get the last stored step number (None if it does not exist)"""
        return self.__sorted_steps[-1] if self.__sorted_steps else None

    @property
    def first_step(self) -> Optional[int]:
        """Get the first stored step number (None if it does not exist)"""
        return self.__sorted_steps[0] if self.__sorted_steps else None