编写自定义数组容器#

numpy 版本 v1.16 中引入的 Numpy 调度机制是编写与 numpy API 兼容并提供 numpy 功能的自定义实现的自定义 N 维数组容器的推荐方法。应用程序包括dask数组(分布在多个节点上的 N 维数组)和cupy数组(GPU 上的 N 维数组)。

为了了解编写自定义数组容器的感觉,我们将从一个简单的示例开始,该示例的实用性相当有限,但说明了所涉及的概念。

>>> import numpy as np
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None):
...         return self._i * np.eye(self._N, dtype=dtype)

我们的自定义数组可以实例化为:

>>> arr = DiagonalArray(5, 1)
>>> arr
DiagonalArray(N=5, value=1)

numpy.array我们可以使用or 转换为 numpy 数组numpy.asarray,它将调用其__array__方法来获取标准numpy.ndarray.

>>> np.asarray(arr)
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

arr如果我们使用 numpy 函数 进行操作,numpy 将再次使用该__array__接口将其转换为数组,然后以通常的方式应用该函数。

>>> np.multiply(arr, 2)
array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 2.]])

请注意,返回类型是标准的numpy.ndarray

>>> type(np.multiply(arr, 2))
<class 'numpy.ndarray'>

我们如何通过这个函数传递我们的自定义数组类型? Numpy 允许类表明它希望通过接口__array_ufunc__和 来以自定义定义的方式处理计算__array_function__。让我们一次选取一个,从 开始__array_ufunc__。此方法涵盖 通用函数 (ufunc),这是一类函数,其中包括 numpy.multiply和 等numpy.sin

接收__array_ufunc__

  • ufunc,像这样的函数numpy.multiply

  • method,一个字符串,区分、等numpy.multiply(...)变体。对于常见情况,, .numpy.multiply.outernumpy.multiply.accumulatenumpy.multiply(...)method == '__call__'

  • inputs,这可能是不同类型的混合

  • kwargs, 传递给函数的关键字参数

对于这个例子,我们将只处理该方法__call__

>>> from numbers import Number
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None):
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented

现在我们的自定义数组类型通过 numpy 函数传递。

>>> arr = DiagonalArray(5, 1)
>>> np.multiply(arr, 3)
DiagonalArray(N=5, value=3)
>>> np.add(arr, 3)
DiagonalArray(N=5, value=4)
>>> np.sin(arr)
DiagonalArray(N=5, value=0.8414709848078965)

此时不行。arr + 3

>>> arr + 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for +: 'DiagonalArray' and 'int'

为了支持它,我们需要定义 Python 接口__add____lt__等来分派到相应的 ufunc。我们可以通过继承 mixin 来方便地实现这一点 NDArrayOperatorsMixin

>>> import numpy.lib.mixins
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None):
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
>>> arr = DiagonalArray(5, 1)
>>> arr + 3
DiagonalArray(N=5, value=4)
>>> arr > 0
DiagonalArray(N=5, value=True)

现在我们来解决一下__array_function__。我们将创建将 numpy 函数映射到我们的自定义变体的字典。

>>> HANDLED_FUNCTIONS = {}
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None):
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 # In this case we accept only scalar numbers or DiagonalArrays.
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...     def __array_function__(self, func, types, args, kwargs):
...         if func not in HANDLED_FUNCTIONS:
...             return NotImplemented
...         # Note: this allows subclasses that don't override
...         # __array_function__ to handle DiagonalArray objects.
...         if not all(issubclass(t, self.__class__) for t in types):
...             return NotImplemented
...         return HANDLED_FUNCTIONS[func](*args, **kwargs)
...

一个方便的模式是定义一个implements可用于向HANDLED_FUNCTIONS.

>>> def implements(np_function):
...    "Register an __array_function__ implementation for DiagonalArray objects."
...    def decorator(func):
...        HANDLED_FUNCTIONS[np_function] = func
...        return func
...    return decorator
...

现在我们为 编写 numpy 函数的实现DiagonalArray。为了完整起见,为了支持用法,arr.sum()添加一个sum调用 的方法numpy.sum(self),对于 也同样mean

>>> @implements(np.sum)
... def sum(arr):
...     "Implementation of np.sum for DiagonalArray objects"
...     return arr._i * arr._N
...
>>> @implements(np.mean)
... def mean(arr):
...     "Implementation of np.mean for DiagonalArray objects"
...     return arr._i / arr._N
...
>>> arr = DiagonalArray(5, 1)
>>> np.sum(arr)
5
>>> np.mean(arr)
0.2

如果用户尝试使用 中未包含的任何 numpy 函数 HANDLED_FUNCTIONSTypeErrornumpy 将引发 a ,表明不支持此操作。例如,连接两个 DiagonalArrays不会产生另一个对角数组,因此不支持。

>>> np.concatenate([arr, arr])
Traceback (most recent call last):
...
TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]

此外,我们的sum和实现mean不接受 numpy 实现所接受的可选参数。

>>> np.sum(arr, axis=0)
Traceback (most recent call last):
...
TypeError: sum() got an unexpected keyword argument 'axis'

用户始终可以选择从那里转换为普通并使用标准 numpy numpy.ndarraynumpy.asarray

>>> np.concatenate([np.asarray(arr), np.asarray(arr)])
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

为了简洁起见,本示例中的实现DiagonalArray仅处理 np.sum和函数。 np.meanNumpy API 中的许多其他函数也可用于包装,并且成熟的自定义数组容器可以显式支持 Numpy 可用于包装的所有函数。

Numpy 提供了一些实用程序来帮助测试在 命名空间中实现__array_ufunc__和协议的自定义数组容器。__array_function__numpy.testing.overrides

要检查 Numpy 函数是否可以通过重写__array_ufunc__,您可以使用allows_array_ufunc_override

>>> from np.testing.overrides import allows_array_ufunc_override
>>> allows_array_ufunc_override(np.add)
True

__array_function__同样,您可以检查是否可以通过使用 来覆盖函数 allows_array_function_override

Numpy API 中每个可重写函数的列表也可通过 get_overridable_numpy_array_functions支持__array_function__协议的函数和 get_overridable_numpy_ufuncs支持协议的函数获得__array_ufunc__。这两个函数都会返回 Numpy 公共 API 中存在的函数集。用户定义的 ufunc 或依赖于 Numpy 的其他库中定义的 ufunc 不存在于这些集中。

有关自定义数组容器的更完整的示例,请参阅dask 源代码cupy 源代码。

另请参阅NEP 18