# Source code for mygrad._utils.lock_management

"""
Provides utilities responsible for locking/releasing array writeability.
"""
import os
from collections import Counter, defaultdict
from typing import (
    TYPE_CHECKING,
    Counter as CounterType,
    DefaultDict,
    Dict,
    Generator,
    Iterable,
    Set,
)
from weakref import finalize, ref

import numpy as np

from mygrad._utils import ContextTracker, WeakRef, WeakRefIterable

if TYPE_CHECKING:  # pragma: no cover
    from mygrad import Tensor as TensorType


# Maps id(array) -> number of active mygrad ops that the array is involved in
_array_counter: CounterType[int] = Counter()

# Maps id(array) -> weak reference to that array, for every array that
# participates in a live op (views that are waiting on their base to be
# unlocked also remain in here)
_array_tracker: Dict[int, WeakRef[np.ndarray]] = {}

# Maps id(base array) -> set of ids of views of that base; each such view
# cannot have its writeability restored until its base has been unlocked
_views_waiting_for_unlock: DefaultDict[int, Set[int]] = defaultdict(set)

__all__ = [
    "lock_arr_writeability",
    "release_writeability_lock_on_op",
    "mem_guard_off",
    "mem_guard_on",
    "mem_guard_active",
]


def array_is_tracked(arr: np.ndarray) -> bool:
    """Returns True if the provided array is currently tracked as being
    involved in one or more mygrad operations.

    Only ``arr`` itself is looked up (by ``id``); views of ``arr`` and its
    base are not inspected. A tracker entry whose weak reference no longer
    resolves to a live array does not count as tracking ``arr``.

    Parameters
    ----------
    arr : numpy.ndarray

    Returns
    -------
    bool
    """
    arr_ref = _array_tracker.get(id(arr))
    # the weak reference returns None once its referent has been collected
    return arr_ref is not None and arr_ref() is not None


def lock_arr_writeability(arr: np.ndarray, force_lock: bool = False) -> np.ndarray:
    """Registers one more active op involving ``arr`` and makes ``arr``
    read-only.

    Parameters
    ----------
    arr : numpy.ndarray
        The array to be locked.

    force_lock : bool, optional (default=False)
        If True, an array that is already read-only will still be tracked
        for unlocking.

    Returns
    -------
    numpy.ndarray
        The same array that was passed in, now read-only."""
    key = id(arr)

    if array_is_tracked(arr):
        # already participating in at least one live op: add another lock
        _array_counter[key] += 1
    else:
        if not force_lock and not arr.flags.writeable:
            base = arr.base
            if base is None or not array_is_tracked(base):
                # read-only, yet neither it nor its base is tracked by us:
                # treat it as natively read-only and leave it untracked
                return arr
        # hold only a weak reference so that the bookkeeping can be cleaned
        # up once the array falls out of scope
        _array_tracker[key] = ref(arr)
        _array_counter[key] = 1

    if arr.flags.writeable is True:
        arr.flags.writeable = False
    return arr


def unique_arrs_and_bases(
    tensors: Iterable["TensorType"],
) -> Generator[np.ndarray, None, None]:
    """
    Lazily yields the arrays backing ``tensors``, without repeats (by ``id``).

    When an array has a non-None ``base``, that base is yielded immediately
    before the array itself, unless the base has already been yielded.
    """
    emitted = set()
    for tensor in tensors:
        data = tensor.data
        if id(data) in emitted:
            continue

        base = data.base
        # Bases must come out ahead of their views: locks get released in
        # this same order, so a base's writeability is restored first and
        # only then can its view's writeability be restored.
        if base is not None and id(base) not in emitted:
            emitted.add(id(base))
            yield base

        emitted.add(id(data))
        yield data


def _release_lock_on_arr_writeability(arr: np.ndarray) -> None:
    """
    Decrements the number of active ops the array participates in.
    An array no longer participating in any ops will have its
    writeability restored.

    If ``arr`` is a view whose base is still read-only, restoring its
    writeability is deferred: the view's id is recorded in
    ``_views_waiting_for_unlock`` under its base's id, and the view stays in
    ``_array_tracker``. Conversely, when ``arr`` is a writeable base array
    (``arr.base is None``) with views waiting on it, each waiting view that
    is not involved in a new op is removed from the waiting set and, if it
    is still alive and tracked, made writeable.

    Parameters
    ----------
    arr : numpy.ndarray
        The array whose lock is being released. An array with no recorded
        active ops (a count of 0) is not affected by the decrement step.
    """
    arr_id = id(arr)
    # Counter yields 0 for ids it has never seen (without inserting them)
    num_active_ops = _array_counter[arr_id]

    if num_active_ops == 1:
        # final active op involving array is being de-referenced:
        # okay to unlock array
        del _array_counter[arr_id]

        if arr.base is not None and arr.base.flags.writeable is False:
            # Array is view and must wait until its base is released
            # before it can be unlocked.
            # Thus we are still tracking this array (its entry in
            # `_array_tracker` is deliberately left in place)
            _views_waiting_for_unlock[id(arr.base)].add(arr_id)
        else:
            # we no longer need to track the array
            arr.flags.writeable = True
            _array_tracker.pop(arr_id, None)
            if not _array_tracker and _views_waiting_for_unlock:
                # Waiting views remain in `_array_tracker`, so if no arrays
                # are being tracked there can't be any views waiting to be
                # unlocked. Clean up any stale entries!
                _views_waiting_for_unlock.clear()
    elif num_active_ops > 0:
        # other ops still hold locks on the array: just decrement
        _array_counter[arr_id] = num_active_ops - 1

    # (the membership test below does not create an entry in the defaultdict)
    if (
        arr.base is None
        and arr.flags.writeable
        and (arr_id in _views_waiting_for_unlock)
    ):
        # array was base of view waiting to be unlocked..
        #
        # For each waiting view, either:
        #    view is involved in a new op -> skipped; it stays in the
        #        waiting set
        #    view no longer exists / is no longer tracked -> removed from
        #        the waiting set
        #    view can now get unlocked -> removed from the waiting set and
        #        made writeable
        # iterate over a snapshot since the set is mutated inside the loop
        for view_arr_id in tuple(_views_waiting_for_unlock[arr_id]):
            if _array_counter[view_arr_id] > 0:
                # view involved in new op
                continue

            _views_waiting_for_unlock[arr_id].remove(view_arr_id)

            try:
                # `pop` raises KeyError if the view is no longer tracked;
                # calling the weak reference gives None if the view is dead
                view_arr = _array_tracker.pop(view_arr_id)()
                if view_arr is None:
                    continue
            except KeyError:
                # view array is no longer available for unlocking
                continue

            try:
                view_arr.flags.writeable = True
            except ValueError:  # pragma: no cover
                # sometimes this raises.. but it is not
                # reproducible and is very rare
                pass

        # drop the base's entry once no views are waiting on it
        if not _views_waiting_for_unlock[arr_id]:
            _views_waiting_for_unlock.pop(arr_id)


def release_writeability_lock_on_op(arr_refs: WeakRefIterable[np.ndarray]):
    """Releases one writeability lock for each array yielded by ``arr_refs``.

    An array is made writeable again only once all of its locks have
    been released; a view whose base is still read-only must additionally
    wait for that base to be unlocked.

    Parameters
    ----------
    arr_refs : WeakRefIterable[np.ndarray]
        The arrays to be unlocked. No de-duplication is performed here:
        an array that occurs ``k`` times in the iterable has ``k`` of its
        locks released, which keeps this balanced with callers that lock
        once per occurrence. Bases are not released implicitly; a base array
        is only released if it appears in the iterable itself."""
    for arr in arr_refs:
        _release_lock_on_arr_writeability(arr)


MEM_GUARD = os.environ.get("MYGRAD_MEM_GUARD", True)

if MEM_GUARD in {"True", "true", "1", 1, True}:
    MEM_GUARD = True
elif MEM_GUARD in {"False", "false", "0", 0, False}:  # pragma: no cover
    MEM_GUARD = False
else:  # pragma: no cover
    from warnings import warn

    warn(
        f"Environment variable MYGRAD_MEM_GUARD was set to an unknown value {MEM_GUARD}. "
        f"Proceeding with `MEM_GUARD=True`"
    )
    MEM_GUARD = True


class MemStateContext(ContextTracker):
    """Context tracker whose ``state`` reads and writes the module-level
    ``MEM_GUARD`` flag."""

    @property
    def state(self):
        """The current value of the global ``MEM_GUARD`` flag."""
        return MEM_GUARD

    @state.setter
    def state(self, value: bool):
        global MEM_GUARD

        if isinstance(value, bool):
            MEM_GUARD = value
            return

        raise TypeError(  # pragma: no cover
            f"MEM_GUARD must be set to a boolean value, got {value} (type={type(value)})"
        )


class _NoMemGuard(MemStateContext):
    """A context manager used to suspend memory-locking behavior

    Examples
    --------
    >>> from mygrad import mem_guard_off
    >>> with mem_guard_off:
    ...     # array-memory locking is turned off
    ...     pass
    >>> # previous memory-locking behavior is restored

    This can also be used as a decorator

    >>> @mem_guard_off
    ... def f():
    ...     # array-memory locking is turned off within function
    ...     return

    """

    # presumably the value that ContextTracker assigns to `state` on
    # entering the context -- NOTE(review): confirm against ContextTracker
    _enter_set_value = False

class _WithMemGuard(MemStateContext):
    """A context manager used to enable memory-locking behavior

    Examples
    --------
    >>> from mygrad import mem_guard_on
    >>> with mem_guard_on:
    ...     # array-memory locking is turned on
    ...     pass
    >>> # previous memory-locking behavior is restored

    This can also be used as a decorator

    >>> @mem_guard_on
    ... def f():
    ...     # array-memory locking is turned on within function
    ...     return

    """

    # presumably the value that ContextTracker assigns to `state` on
    # entering the context -- NOTE(review): confirm against ContextTracker
    _enter_set_value = True


# Module-level singleton instances; per their class docstrings, each can be
# used both as a context manager and as a decorator
mem_guard_off = _NoMemGuard()
mem_guard_on = _WithMemGuard()


def turn_memory_guarding_off():
    """Globally disables all memory-guarding mechanisms, except for in
    contexts where they are explicitly enabled.

    Notes
    -----
    With memory guarding disabled, arrays participating in active
    computational graphs are not protected from being mutated by the user.
    Mutating such an array will corrupt the derivatives that are computed
    via back-propagation, and will produce incorrect results.

    This can speed up computations involving many small tensors
    substantially.

    If you want to disable memory guarding at the system level, you
    can set the system environment variable MYGRAD_MEM_GUARD=False.
    NOTE THAT THIS IS NOT RECOMMENDED.

    See Also
    --------
    turn_memory_guarding_on : Globally enables all memory-guarding mechanisms
    mem_guard_off : context manager & decorator for suspending memory guarding
    mem_guard_on : context manager & decorator for enabling memory guarding

    Examples
    --------
    The following demonstrates how one can unwittingly corrupt
    backpropagation through a computational graph

    >>> import mygrad as mg
    >>> import numpy as np
    >>> mg.turn_memory_guarding_off()  # speeds up calculations, but with risks involved..
    >>> x = np.arange(3.)
    >>> y = mg.ones_like(x)
    >>> z = x * y
    >>> x[:] = 0  # mutates x, corrupting state associated with z
    >>> z.backward()
    >>> y.grad  # would be array([0., 1., 2.]) if graph wasn't corrupted
    array([0., 0., 0.])
    >>> mg.turn_memory_guarding_on()  # restore the default behavior
    """
    global MEM_GUARD
    MEM_GUARD = False
def turn_memory_guarding_on():
    """Globally enables all memory-guarding mechanisms, except for in
    contexts where they are explicitly disabled.

    Notes
    -----
    Memory guarding is enabled by default. It ensures that arrays that are
    participating in computational graphs cannot be mutated (at least
    unwittingly..), which provides important assurances that the state of
    the computational graph is not corrupted for back-propagation.

    Memory guarding can slow down computations involving many small tensors.
    Realistic worst-case benchmarks suggest a ~50% slowdown.

    If performance is important, it is recommended that you test your code
    leaving memory guarding enabled. Presuming the code runs without any
    errors regarding writing to read-only arrays, you can proceed to disable
    memory guarding and enjoy the concomitant speedups.

    Note also that running your code in a `no_autodiff` context will
    automatically disable memory guarding.

    See Also
    --------
    turn_memory_guarding_off : Globally disables all memory-guarding mechanisms
    mem_guard_off : context manager & decorator for suspending memory guarding
    mem_guard_on : context manager & decorator for enabling memory guarding
    no_autodiff : context manager for disabling graph-tracking for back propagation

    Examples
    --------
    The following demonstrates how memory guarding prevents one from
    unwittingly corrupting an active computational graph

    >>> import mygrad as mg
    >>> import numpy as np
    >>> # (Note that memory guarding is on by default, so
    >>> # this call isn't usually needed...)
    >>> mg.turn_memory_guarding_on()
    >>> x = np.arange(3.)
    >>> y = mg.ones_like(x)
    >>> z = x * y
    >>> try:
    ...     x[:] = 0  # raises because `x` is made read-only
    ... except ValueError:
    ...     pass
    >>> z.backward()
    >>> y.grad  # correct gradient is computed
    array([0., 1., 2.])
    """
    global MEM_GUARD
    MEM_GUARD = True


def mem_guard_active() -> bool:
    """Indicates whether or not memory guarding is active.

    See Also
    --------
    turn_memory_guarding_on : Globally enables all memory-guarding mechanisms
    turn_memory_guarding_off : Globally disables all memory-guarding mechanisms
    mem_guard_off : context manager & decorator for suspending memory guarding
    mem_guard_on : context manager & decorator for enabling memory guarding
    """
    return MEM_GUARD


def force_lock_tensor_and_creators(tensor: "TensorType"):
    """Locks the arrays involved in the operation that created ``tensor``.

    Each unique array backing ``tensor.creator.variables`` is locked via
    ``lock_arr_writeability``; for views, the base is locked first.
    ``tensor.data`` itself is locked with ``force_lock=True``, so it is
    tracked even if it is already read-only.

    A ``weakref.finalize`` callback is registered on ``tensor.creator``. When
    the creator is finalized, it releases one lock on each of these arrays
    via ``release_writeability_lock_on_op``.

    Parameters
    ----------
    tensor : Tensor
        Presumably the output of an op, i.e. ``tensor.creator`` is not None.
        This is not checked here (NOTE(review): confirm at call sites).
    """
    # the same order (bases before views) is used when the locks are released
    unique_arrs = tuple(
        lock_arr_writeability(arr)
        for arr in unique_arrs_and_bases(tensor.creator.variables)
    )
    lock_arr_writeability(tensor.data, force_lock=True)

    # hold weak references only, so the finalizer doesn't keep arrays alive
    tensor_refs = WeakRefIterable(unique_arrs)
    tensor_refs.append(tensor.data)
    finalize(
        tensor.creator,
        release_writeability_lock_on_op,
        tensor_refs,
    )