解码NumPy的点积:对维度魔法的简要探索

Exploring the Magic of Dimensionality in NumPy's Dot Product

澄清NumPy点积的混淆

使用DreamStudio生成的图像,提示为“一个充满代码巫师的混乱,黑暗,阴郁,多维世界”

介绍

在处理NumPy中的维度时,我是不是唯一一个经常感到困惑的人?今天,在阅读Gradio的文档页面时,我遇到了以下代码片段:

sepia_filter = np.array([    [0.393, 0.769, 0.189],     [0.349, 0.686, 0.168],    [0.272, 0.534, 0.131],])# input_img shape (H, W, 3)# sepia_filter shape (3, 3)sepia_img = input_img.dot(sepia_filter.T)  # <- 为什么这是合法的?sepia_img /= sepia_img.max()

嘿,嘿,嘿!为什么图像(W,H,3)与滤波器(3,3)的点积是合法的?我让ChatGPT解释给我听,但它开始给我错误的答案(比如说这行不通)或者忽视我的问题,转而回答其他事情。所以,除了动动我的脑筋(再加上阅读文档,叹气),没有其他解决办法。

如果你对上面的代码也有点困惑,请继续阅读。

点积:一个通用例子

来自NumPy点积文档(有所修改):

如果a.shape = (I, J, C),b.shape = (K, C, L),那么dot(a, b)[i, j, k, l] = sum(a[i, j, :] * b[k, :, l])。请注意,“a”的最后一个维度等于“b”的倒数第二个维度。

或者,用代码表示:

I, J, K, L, C = 10, 20, 30, 40, 50a = np.random.random((I, J, C))b = np.random.random((K, C, L))c = a.dot(b)i, j, k, l = 3, 2, 4, 5print(c[i, j, k, l])print(sum(a[i, j, :] * b[k, :, l]))

输出(相同结果):

13.12501290128471313.125012901284713

理解NumPy点积的形状

要预先确定点积的形状,请按照以下步骤进行:

步骤1:考虑两个数组“a”和“b”,以及它们各自的形状。

# 数组a和b的示例形状a_shape = (4, 3, 2)b_shape = (3, 2, 5)# 使用指定的形状创建随机数组a = np.random.random(a_shape)b =...