取一个酉矩阵U
。我想交换列,使得每列的最大元素(绝对值)位于对角线上(模数关系)。在 numpy 中执行此操作的最佳方法是什么?
主页
/
user-8800836
Ben's questions
Ben
Asked:
2024-12-06 03:29:25 +0800 CST
我仍在掌握最佳实践jax
。我的主要问题如下:
实现自定义数组创建例程的最佳实践是什么jax
?
例如,我想实现一个函数,创建一个矩阵,其中除给定列中的 1 外,其他列均为 0。我选择了这个(Jupyter 笔记本):
import numpy as np
import jax.numpy as jnp
def ones_at_col(shape_mat, idx):
idxs = jnp.arange(shape_mat[1])[None,:]
mat = jnp.where(idx==idxs, 1, 0)
mat = jnp.repeat(mat, shape_mat[0], axis=0)
return mat
shape_mat = (5,10)
print(ones_at_col(shape_mat, 5))
%timeit np.zeros(shape_mat)
%timeit jnp.zeros(shape_mat)
%timeit ones_at_col(shape_mat, 5)
输出为
[[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]]
127 ns ± 0.717 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
31.3 µs ± 331 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
123 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
我的功能比常规功能慢了 4 倍jnp.zeros()
,这还不算太糟糕。这说明我做的事情并不疯狂。
但是这两个jax
例程都比等效例程慢得多numpy
。这些函数无法进行 jitted,因为它们将形状作为参数,因此无法跟踪。我猜这就是它们天生就慢的原因?我猜如果它们中的任何一个出现在另一个 jitted 函数的范围内,它们可以被跟踪并加速?
我能做得更好吗?或者我是否正在突破可能的极限jax
?
Ben
Asked:
2024-11-26 08:30:40 +0800 CST
我有一个零维numpy
标量s
和一个二维numpy
矩阵m
。我想形成一个向量矩阵,其中的所有元素都m
与配对,s
如下例所示:
import numpy as np
s = np.asarray(5)
m = np.asarray([[1,2],[3,4]])
# Result should be as follows
array([[[5, 1],
[5, 2]],
[[5, 3],
[5, 4]]])
换句话说,我想np.asarray([s, m])
在 的最低级别上逐元素地矢量化操作。对于内的m
任何多维数组,是否有一种明显的方法来做到这一点?m
numpy
我确信这个在某个地方,但我无法用语言表达,也找不到它。如果你能找到它,请随时将我重定向到那里。