解码NumPy的点积:对维度魔法的简要探索
Exploring the Magic of Dimensionality in NumPy's Dot Product
澄清NumPy点积的混淆
介绍
在处理NumPy中的维度时,我是不是唯一一个经常感到困惑的人?今天,在阅读Gradio的文档页面时,我遇到了以下代码片段:
sepia_filter = np.array([ [0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131],])# input_img shape (H, W, 3)# sepia_filter shape (3, 3)sepia_img = input_img.dot(sepia_filter.T) # <- 为什么这是合法的?sepia_img /= sepia_img.max()
嘿,嘿,嘿!为什么图像(W,H,3)与滤波器(3,3)的点积是合法的?我让ChatGPT解释给我听,但它开始给我错误的答案(比如说这行不通)或者忽视我的问题,转而回答其他事情。所以,除了动动我的脑筋(再加上阅读文档,叹气),没有其他解决办法。
如果你对上面的代码也有点困惑,请继续阅读。
点积:一个通用例子
来自NumPy点积文档(有所修改):
如果a.shape = (I, J, C),b.shape = (K, C, L),那么dot(a, b)[i, j, k, l] = sum(a[i, j, :] * b[k, :, l])。请注意,“a”的最后一个维度等于“b”的倒数第二个维度。
或者,用代码表示:
I, J, K, L, C = 10, 20, 30, 40, 50a = np.random.random((I, J, C))b = np.random.random((K, C, L))c = a.dot(b)i, j, k, l = 3, 2, 4, 5print(c[i, j, k, l])print(sum(a[i, j, :] * b[k, :, l]))
输出(相同结果):
13.12501290128471313.125012901284713
理解NumPy点积的形状
要预先确定点积的形状,请按照以下步骤进行:
步骤1:考虑两个数组“a”和“b”,以及它们各自的形状。
# 数组a和b的示例形状a_shape = (4, 3, 2)b_shape = (3, 2, 5)# 使用指定的形状创建随机数组a = np.random.random(a_shape)b =...