NumPy参考 >例行程序 >线性代数(numpy.linalg) > numpy.linalg.multi_dot
numpy.linalg.
multi_dot
(数组)[源代码] ¶在一个函数调用中计算两个或多个数组的点积,同时自动选择最快的评估顺序。
multi_dot
链numpy.dot
并使用矩阵的最佳括号[1] [2]。根据矩阵的形状,这可以大大加快乘法。
如果第一个参数是1-D,则将其视为行向量。如果最后一个参数是1-D,则将其视为列向量。其他参数必须为二维。
认为multi_dot
是:
def multi_dot(arrays): return functools.reduce(np.dot, arrays)
如果第一个参数是1-D,则将其视为行向量。如果最后一个参数是1-D,则将其视为列向量。其他参数必须为二维。
返回提供的数组的点积。
也可以看看
dot
两个参数的点乘法。
笔记
矩阵乘法的成本可以通过以下函数计算:
def cost(A, B):
return A.shape[0] * A.shape[1] * B.shape[1]
假设我们有三个矩阵
。
两种不同的括号的成本如下:
cost((AB)C) = 10*100*5 + 10*5*50 = 5000 + 2500 = 7500
cost(A(BC)) = 10*100*50 + 100*5*50 = 50000 + 25000 = 75000
参考文献
例子
multi_dot
允许您编写:
>>> from numpy.linalg import multi_dot
>>> # Prepare some data
>>> A = np.random.random((10000, 100))
>>> B = np.random.random((100, 1000))
>>> C = np.random.random((1000, 5))
>>> D = np.random.random((5, 333))
>>> # the actual dot multiplication
>>> _ = multi_dot([A, B, C, D])
代替:
>>> _ = np.dot(np.dot(np.dot(A, B), C), D)
>>> # or
>>> _ = A.dot(B).dot(C).dot(D)