文章目录
创建Tensor
- from numpy, list
- zeros, ones
- fill
- random
- constant
- Application
From Numpy,List
import tensorflow as tf
import numpy as np
numpy创建数据并转为Tensor
tf.convert_to_tensor(np.ones([2,3]))
<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[1., 1., 1.],
[1., 1., 1.]])>
tf.convert_to_tensor(np.zeros([2,3]))
<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[0., 0., 0.],
[0., 0., 0.]])>
list 创建数据并转为Tensor
tf.convert_to_tensor([1,2])
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>
tf.convert_to_tensor([1,2.])
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>
tf.convert_to_tensor([[1], [2.]])
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[1.],
[2.]], dtype=float32)>
tf.zeros
# 标量
tf.zeros([])
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
# 1维
tf.zeros([1])
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>
# 2维
tf.zeros([2,2])
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 0.],
[0., 0.]], dtype=float32)>
# 3维
tf.zeros([2,3,3])
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]], dtype=float32)>
tf.zeros_like
a = tf.zeros([2,3,3])
a
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]], dtype=float32)>
# 根据传递进来的对象创建相同的shape数组
tf.zeros_like(a)
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]], dtype=float32)>
a.shape
TensorShape([2, 3, 3])
# 与tf.zeros_like等同
tf.zeros(a.shape)
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]], dtype=float32)>
tf.ones
# 1维
tf.ones(1)
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>
# 标量
tf.ones([])
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
tf.ones([2])
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
tf.ones([2,3])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 1.],
[1., 1., 1.]], dtype=float32)>
tf.ones_like(a)
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]], dtype=float32)>
# 等同tf.ones_like
tf.ones(a.shape)
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]], dtype=float32)>
tf.fill(自定义填充)
tf.fill([2,2],value=0)
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[0, 0],
[0, 0]], dtype=int32)>
tf.fill([2,2],value=0.)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 0.],
[0., 0.]], dtype=float32)>
tf.fill([2,2],value=1)
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 1],
[1, 1]], dtype=int32)>
tf.fill([2,2],value=9)
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[9, 9],
[9, 9]], dtype=int32)>
tf.random.normal(服从指定正态分布的序列)
# Draw random values from a normal distribution with the given parameters.
# shape: shape of the output tensor (here [2,2])
# mean: mean of the normal distribution
# stddev: standard deviation of the normal distribution
# seed: used to create a random seed for the distribution (not passed here)
tf.random.normal([2,2],mean=1,stddev=1)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 1.2933502 , -0.03046083],
[-0.20453262, 1.8908376 ]], dtype=float32)>
tf.random.normal([2,2])
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 0.3923446, -1.0170017],
[-2.2753727, 1.1124496]], dtype=float32)>
# Truncated normal distribution: values are still normally distributed, but only
# values within two standard deviations of the mean are produced.
tf.random.truncated_normal([2,2],mean=0,stddev=1)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.09197889, 0.7197461 ],
[-1.2158926 , 0.64242524]], dtype=float32)>
tf.random.uniform(均匀分布)
# 0-1之间均匀采样
tf.random.uniform([2,2],minval=0,maxval=1)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0.59650195, 0.8330668 ],
[0.3798474 , 0.04617167]], dtype=float32)>
# 0-100之间均匀采样
tf.random.uniform([3,4],minval=0,maxval=100)
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[39.938354 , 20.946777 , 88.78066 , 65.06464 ],
[67.6425 , 53.750908 , 63.23085 , 1.2730479],
[53.01423 , 93.832565 , 41.489197 , 97.90436 ]], dtype=float32)>
Random Permutation(随机排列)
# 生成数据
idx=tf.range(10)
idx
<tf.Tensor: shape=(10,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)>
# 打散数据
idx=tf.random.shuffle(idx)
idx
<tf.Tensor: shape=(10,), dtype=int32, numpy=array([8, 0, 5, 2, 4, 1, 9, 3, 7, 6], dtype=int32)>
# Normally distributed data: 10 samples x 784 features
a=tf.random.normal([10,784])
# Uniformly distributed integer labels between 0 and 10, one per sample
b=tf.random.uniform([10], minval=0,maxval=10,dtype=tf.int32)
# tf.gather takes slices of params along an axis (axis 0 here) at the positions in indices.
# Using the same shuffled idx for both a and b keeps every sample paired with its label.
a=tf.gather(a,indices=idx)
b=tf.gather(b,indices=idx)
a,b
(<tf.Tensor: shape=(10, 784), dtype=float32, numpy=
array([[ 1.2042824 , -1.3816946 , -0.34105217, ..., 1.7581135 ,
-0.0409344 , -0.7366836 ],
[ 0.23826773, 0.7524342 , 1.1266931 , ..., -0.23611061,
0.03666063, 0.92886865],
[-0.6004217 , -0.10280953, 0.3276803 , ..., -0.5391498 ,
-0.08050299, 0.42966053],
...,
[-0.9661765 , 0.38981438, -1.3240955 , ..., 0.27162325,
0.7889554 , -0.28455386],
[-0.8107048 , 0.18297605, 1.7718701 , ..., 0.1255566 ,
0.8605448 , -0.01122306],
[ 1.3737792 , 0.00462327, -0.12949888, ..., 1.4805819 ,
-0.79014236, 0.8239969 ]], dtype=float32)>,
<tf.Tensor: shape=(10,), dtype=int32, numpy=array([0, 2, 6, 5, 7, 6, 8, 0, 2, 1], dtype=int32)>)
tf.constant
# 从类似张量的对象创建常量张量
tf.constant(1)
<tf.Tensor: shape=(), dtype=int32, numpy=1>
tf.constant(1.)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
tf.constant([1])
<tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>
tf.constant([1,2.])
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>
# Left commented out on purpose: its rows have different lengths ([1,2.] vs [3.]),
# and tf.constant cannot build a tensor from a non-rectangular nested list (it raises an error).
# tf.constant([[1,2.],[3.]])
tf.constant([[1,2.],[3.,2]])
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 2.],
[3., 2.]], dtype=float32)>
Typical Dim Data(典型维度数据)
- []:标量 Scalar,例如 1, 2, 2.3
- [d]
- [h,w]
- [b, len, vec]
- [b,h,w,c]
- [t,b,h,w,c]
- …
Scalar 标量
- []
- 1,2,2.3
- loss=mse(out,y)
- accuracy
LOSS
# 生成均匀分布数据
out=tf.random.uniform([4,10])
out
<tf.Tensor: shape=(4, 10), dtype=float32, numpy=
array([[0.7957759 , 0.6470387 , 0.85942817, 0.9195243 , 0.80446124,
0.9332032 , 0.10627711, 0.0725143 , 0.88076556, 0.8592241 ],
[0.92300546, 0.9134604 , 0.79779387, 0.44595373, 0.3783729 ,
0.5998908 , 0.33701885, 0.00334263, 0.1695168 , 0.24933684],
[0.91918683, 0.2754066 , 0.7670957 , 0.5798063 , 0.11099887,
0.94225526, 0.90335643, 0.06025386, 0.1178335 , 0.60193515],
[0.17284477, 0.09774017, 0.3873669 , 0.99980617, 0.3310578 ,
0.60174 , 0.7005192 , 0.09624684, 0.8688313 , 0.3498112 ]],
dtype=float32)>
# 生成标签数据
y=tf.range(4)
y
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>
# One-hot encoding: label k becomes a length-10 vector with a 1 at index k (see output below)
y=tf.one_hot(y,depth=10)
y
<tf.Tensor: shape=(4, 10), dtype=float32, numpy=
array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)>
$$loss=\frac{1}{N}\sum_{i=1}^{N}(y_i-out_i)^2$$
(`tf.keras.losses.mse` 在最后一维上取平均,此处 N=10,因此每个样本得到一个 loss 值)
# Mean squared error between the labels y and the predictions out.
# It is averaged over the last axis only, so there is one loss value per sample (shape (4,) below).
loss=tf.keras.losses.mse(y,out)
loss
<tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.5093101 , 0.24023107, 0.34072328, 0.20383982], dtype=float32)>
# Average the per-sample losses over all elements to get a single scalar loss.
loss=tf.reduce_mean(loss)
loss
<tf.Tensor: shape=(), dtype=float32, numpy=0.32352605>
Vector
- Bias
- [out_dim]
- y=X@W + b(bias)
from tensorflow.keras import layers, optimizers, datasets
# Dense: a fully connected layer computing y = X@W + b.
# It plays the same role as the hand-written add_layer() function used in earlier tutorials.
# 8 input features -> 10 output units
net=layers.Dense(10)
# build() creates the weights for inputs of shape (4, 8): kernel (8, 10) and bias (10,), as shown below.
net.build((4,8))
net.kernel
<tf.Variable 'kernel:0' shape=(8, 10) dtype=float32, numpy=
array([[ 0.01621449, -0.14240894, -0.39780322, 0.12109792, 0.5365114 ,
0.48510587, 0.10659033, -0.10367435, -0.16618788, -0.28067237],
[ 0.4738747 , 0.23055995, -0.28443712, -0.45606828, -0.2740961 ,
-0.22440037, -0.179019 , -0.09799156, -0.5101716 , -0.12497845],
[ 0.13550311, -0.30709794, -0.43564707, 0.5684767 , 0.02459109,
0.40579963, 0.41654438, -0.12095284, 0.5667572 , 0.3747304 ],
[ 0.11231458, 0.48574615, 0.24569553, 0.49175394, -0.23326564,
-0.18449727, -0.2591746 , 0.38994485, 0.47833967, 0.42860043],
[-0.01137727, 0.39716053, 0.11606997, 0.57487595, -0.09682512,
-0.25918382, 0.04865551, -0.02905029, -0.54025316, -0.39565778],
[-0.18153116, -0.32627657, 0.53890324, -0.14481756, 0.5542736 ,
0.2505871 , -0.15414259, 0.43023145, -0.01220387, -0.15518737],
[ 0.04331303, 0.4358071 , -0.18711427, -0.24738216, -0.02527547,
0.47772157, 0.35470355, 0.23595202, 0.28886396, -0.16550332],
[-0.53486687, -0.16493705, -0.44315812, 0.11054051, 0.38065422,
-0.14305219, 0.5382327 , -0.0840193 , 0.46640956, -0.10600576]],
dtype=float32)>
# b: the bias vector, one entry per output unit (all zeros right after build, per the output below)
net.bias
<tf.Variable 'bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
Matrix
- input x: [b, vec_dim]
- weight: [input_dim, output_dim]
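$$
W=\begin{bmatrix}
w_{1,1} & w_{1,2} & \cdots & w_{1,R} \\
w_{2,1} & w_{2,2} & \cdots & w_{2,R} \\
\vdots & \vdots & \ddots & \vdots \\
w_{S,1} & w_{S,2} & \cdots & w_{S,R}
\end{bmatrix}
$$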
x=tf.random.normal([4,784])
x.shape
TensorShape([4, 784])
# 784 -> 10
net=layers.Dense(10)
net.build((4,784))
net(x).shape
TensorShape([4, 10])
net.kernel.shape
TensorShape([784, 10])
net.bias.shape
TensorShape([10])
Dim=3 Tensor
- x: [b, seq_len, word_dim]
- [b, 5, 5]
# Load the IMDB review dataset, keeping only the 10000 most frequent words (num_words=10000)
(X_train,y_train),(X_test,y_test)=tf.keras.datasets.imdb.load_data(num_words=10000)
# Inspect the first sample: its word indices, its length, and its Python type
print(X_train[0])
print(len(X_train[0]))
print(type(X_train[0]))
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
218
<class 'list'>
X_train.shape
(25000,)
pad_sequences 序列预处理
- sequences:浮点数或整数构成的两层嵌套列表
- maxlen:None或整数,为序列的最大长度。大于此长度的序列将被截断(截去开头还是结尾由 truncating 决定),小于此长度的序列将补0(补在开头还是结尾由 padding 决定)。在命名实体识别任务中,主要是指句子的最大长度
- dtype:返回的numpy array的数据类型
- padding:‘pre’或‘post’(默认‘pre’),确定当需要补0时,在序列的起始还是结尾补
- truncating:‘pre’或‘post’(默认‘pre’),确定当需要截断序列时,从起始还是结尾截断
- value:浮点数,此值将在填充时代替默认的填充值0
# Sequence preprocessing: pad/truncate every review to exactly maxlen=80 word indices.
# The first sample (218 words) keeps only its last 80 indices, as the output below shows.
x_train=tf.keras.preprocessing.sequence.pad_sequences(X_train,maxlen=80)
print(x_train[0])
print(len(x_train[0]))
print(type(x_train[0]))
[ 15 256 4 2 7 3766 5 723 36 71 43 530 476 26
400 317 46 7 4 2 1029 13 104 88 4 381 15 297
98 32 2071 56 26 141 6 194 7486 18 4 226 22 21
134 476 26 480 5 144 30 5535 18 51 36 28 224 92
25 104 4 226 65 16 38 1334 88 12 16 283 5 16
4472 113 103 32 15 16 5345 19 178 32]
80
<class 'numpy.ndarray'>
x_train.shape
(25000, 80)
# Embedding: map each word index to a dense vector -> [b, seq_len, word_dim].
# input_dim=10000 matches num_words=10000 used when loading IMDB, so every index is in range.
embedding = layers.Embedding(input_dim=10000, output_dim=100)
# Embed only a small batch: embedding all 25000 padded reviews at once would
# allocate 25000 * 80 * 100 float32 values (~800 MB).
emb = embedding(x_train[:4])
emb.shape  # expected: TensorShape([4, 80, 100])
# Run the [4, 80, 100] batch through a simple recurrent layer with 64 units;
# without return_sequences it returns only the last step's output per sample.
rnn = layers.SimpleRNN(64)
out = rnn(emb)
out.shape  # expected: TensorShape([4, 64])
Dim=4 Tensor
- Image: [b, h, w, 3]
- [b,28,28,1]
- feature maps: [b, h, w, c]
- [b, h, w, 3]
# A batch of 4 random 32x32 images with 3 channels: [b, h, w, c]
x=tf.random.normal((4,32,32,3))
# Convolution layer: 16 filters with a 3x3 kernel. With no padding, the 32x32 input
# shrinks to 30x30, giving the (4, 30, 30, 16) output below.
net=layers.Conv2D(16,kernel_size=3)
net(x)
<tf.Tensor: shape=(4, 30, 30, 16), dtype=float32, numpy=
array([[[[ 5.11408783e-02, -2.11419791e-01, 4.78755325e-01, ...,
-7.09839880e-01, -1.62664950e-01, -3.69613528e-01],
[ 1.04834512e-01, 4.45832849e-01, 3.67929399e-01, ...,
-5.73501050e-01, -9.47547019e-01, -2.73244321e-01],
[ 6.64293393e-02, -3.11614126e-01, 8.88543963e-01, ...,
-5.91010869e-01, 8.83403420e-02, -6.25845611e-01],
...,
[ 7.06125498e-01, -3.92723858e-01, 8.19511190e-02, ...,
3.88073295e-01, -3.65344107e-01, -4.27012980e-01],
[ 1.42910075e+00, -3.73372495e-01, 1.46203363e+00, ...,
-2.29273185e-01, 1.61577657e-01, -3.62703115e-01],
[ 1.52406737e-01, 9.83081281e-01, -9.29397166e-01, ...,
-4.01978552e-01, -1.65519774e+00, -6.72560513e-01]],
[[ 3.38849723e-01, 5.57779014e-01, -4.73486066e-01, ...,
6.66084468e-01, -2.25408077e-01, 2.01416627e-01],
[ 9.98801887e-02, 3.51594418e-01, -1.17291820e+00, ...,
-4.78315353e-01, 1.71240568e-01, 2.29796067e-01],
[ 1.28986561e+00, 4.30635571e-01, 2.68138438e-01, ...,
3.85428101e-01, -1.30183734e-02, 7.93949187e-01],
...,
[ 8.19463372e-01, 4.61028397e-01, 2.90335417e-02, ...,
-2.65352666e-01, -9.04486954e-01, 4.05555129e-01],
[ 2.95304179e-01, 7.62058318e-01, -6.67063117e-01, ...,
-2.71042168e-01, 6.85480416e-01, 2.40263894e-01],
[ 3.03411335e-01, -7.25541830e-01, -8.30628812e-01, ...,
-7.50274420e-01, -1.50119960e-01, -3.95896226e-01]],
[[-2.56495386e-01, -7.38008507e-03, -1.67473584e-01, ...,
2.62954921e-01, -8.31158459e-01, -9.43797290e-01],
[-3.24434787e-01, 7.00764358e-01, -7.61871561e-02, ...,
-1.63662970e-01, -2.82550633e-01, 9.41351712e-01],
[-4.85321730e-01, -2.02089727e-01, -5.27621686e-01, ...,
2.42007345e-01, 2.84019470e-01, 2.84560677e-02],
...,
[-2.54487991e-01, 3.96051168e-01, -6.56044304e-01, ...,
-5.79203784e-01, -7.97876954e-01, -6.80033207e-01],
[-4.36108887e-01, -4.78615195e-01, 1.12478447e+00, ...,
-1.31961375e-01, 1.66062370e-01, -5.76823093e-02],
[-1.49053439e-01, 1.66869834e-01, -5.55442452e-01, ...,
3.28702837e-01, -1.06465340e+00, 2.32470438e-01]],
...,
[[ 3.62408757e-01, -1.23813343e+00, 7.57653639e-02, ...,
3.63634303e-02, 8.34123254e-01, -1.68675110e-01],
[ 3.52181286e-01, 9.89539027e-01, -1.13494694e+00, ...,
1.14149503e-01, -9.32057440e-01, -4.54324901e-01],
[ 5.59438944e-01, -9.75797400e-02, -2.96774745e-01, ...,
-9.09146965e-02, -1.40010104e-01, -2.78197140e-01],
...,
[-6.19658291e-01, -1.24590620e-01, -4.91972193e-02, ...,
-2.09599838e-01, 2.36650333e-01, -2.43948981e-01],
[ 4.20804381e-01, -3.13818097e-01, 6.72074437e-01, ...,
1.59845605e-01, -1.21049285e-01, 5.39864361e-01],
[-1.12884209e-01, -9.41875696e-01, 3.67895700e-02, ...,
-4.34767634e-01, 1.29640436e+00, 2.41676748e-01]],
[[ 6.18064642e-01, 5.64845145e-01, 4.26679671e-01, ...,
6.53841555e-01, -7.71579087e-01, 6.57868683e-02],
[ 8.64152551e-01, -5.17681301e-01, 1.43626824e-01, ...,
-1.70546189e-01, 8.96106884e-02, 4.54999506e-01],
[-4.78937477e-01, -4.56122994e-01, 6.41848207e-01, ...,
-1.19053459e+00, 1.20916940e-01, 2.71652102e-01],
...,
[ 9.20986295e-01, 7.50426650e-01, 5.38934469e-01, ...,
4.02694009e-02, -6.34055018e-01, 1.87944248e-02],
[-1.02507330e-01, 1.09432077e+00, 9.18149129e-02, ...,
2.04167053e-01, -1.80424526e-01, -2.74709523e-01],
[-5.92706978e-01, -4.26084399e-01, -6.30943716e-01, ...,
-2.03500003e-01, 5.43525159e-01, -7.80534931e-03]],
[[ 3.69665354e-01, 5.48572429e-02, 7.47180283e-01, ...,
-4.25332904e-01, 3.04142982e-01, 6.80612326e-02],
[-2.93423682e-01, 2.24400461e-01, -1.75468147e-01, ...,
-1.21893203e+00, 2.17650369e-01, 5.83284616e-01],
[-4.92672741e-01, 7.17395246e-02, -1.09146468e-01, ...,
-3.73993725e-01, 4.24882956e-02, 9.17238295e-02],
...,
[ 6.62829280e-02, -1.58003271e-01, -1.24654186e+00, ...,
-2.55260974e-01, 1.85372248e-01, 2.66077727e-01],
[-1.95675135e-01, 5.16129769e-02, -3.96091282e-01, ...,
5.76506853e-01, -4.20435786e-01, -3.86470616e-01],
[ 3.98471832e-01, 2.92219281e-01, 4.01777089e-01, ...,
1.58537793e+00, -9.73537207e-01, 2.93004245e-01]]],
[[[-4.88410741e-01, 3.35286170e-01, 2.36457393e-01, ...,
1.26244557e+00, -3.53360593e-01, -4.18140233e-01],
[-3.03056449e-01, 7.47030601e-02, 5.35628617e-01, ...,
3.22452217e-01, -1.48300707e-01, 2.03683227e-01],
[ 8.70249152e-01, 6.30708635e-01, -1.95097223e-01, ...,
-8.11083496e-01, -5.93419313e-01, -1.81713790e-01],
...,
[-8.20982456e-02, 5.62634587e-01, -2.71823734e-01, ...,
8.18851683e-03, -3.14377576e-01, -9.49833021e-02],
[ 3.77315246e-02, 2.37164915e-01, -1.12752795e+00, ...,
6.82983696e-01, -2.03109041e-01, 3.27775270e-01],
[-3.42391729e-01, 4.12799299e-01, -1.03908813e+00, ...,
-1.77690387e-01, -1.95123062e-01, 7.12375492e-02]],
[[-6.04353905e-01, -4.86152679e-01, 7.74208426e-01, ...,
-3.77850741e-01, 6.32216828e-03, -2.98906416e-01],
[ 3.03069472e-01, 1.45399764e-01, 2.10653007e-01, ...,
9.63516235e-01, 3.69627088e-01, 2.55155087e-01],
[ 2.18334988e-01, 6.62136912e-01, -1.25356126e+00, ...,
-2.10385859e-01, 6.05388105e-01, 5.24035275e-01],
...,
[ 7.21290410e-01, -1.96054131e-01, -2.53071606e-01, ...,
9.81350616e-03, -3.36887777e-01, -8.73216391e-01],
[ 7.26956069e-01, -2.55486459e-01, 1.13994825e+00, ...,
-7.81499624e-01, -9.61165130e-01, -3.59825432e-01],
[ 1.11502385e+00, -7.02142417e-01, 1.41868222e+00, ...,
-9.77690816e-01, -2.14777321e-01, 4.68507677e-01]],
[[ 3.55722100e-01, 4.81817633e-01, 4.14835602e-01, ...,
5.47380567e-01, 5.61807156e-01, 9.30055007e-02],
[-5.89577481e-02, 3.67752552e-01, -2.10168734e-01, ...,
-1.17150187e+00, -1.30239666e-01, -7.32884407e-01],
[ 7.63650179e-01, -7.12301672e-01, 1.14955544e+00, ...,
-1.16633463e+00, 5.95600665e-01, 9.20139626e-02],
...,
[ 5.23808539e-01, -1.08179323e-01, 3.20617974e-01, ...,
-7.01221347e-01, 3.13195109e-01, 6.03635669e-01],
[-1.10597581e-01, 6.05741262e-01, -3.21878254e-01, ...,
-9.76488471e-01, 8.60475838e-01, -3.45011771e-01],
[ 1.34457603e-01, 6.90276265e-01, -4.44975734e-01, ...,
-7.25347996e-01, 1.29672837e+00, 8.13677251e-01]],
...,
[[-3.14656645e-01, 4.40141082e-01, -1.73647985e-01, ...,
-9.47391808e-01, -7.14678824e-01, -1.45510659e-01],
[ 2.90451586e-01, 4.85799819e-01, 3.68700057e-01, ...,
3.72695416e-01, 9.44734871e-01, 1.65997118e-01],
[ 9.50033963e-02, -7.15577960e-01, -1.43527782e+00, ...,
-3.88930261e-01, -6.76119626e-01, -9.19058025e-01],
...,
[ 4.16554570e-01, 8.66417825e-01, -4.78255659e-01, ...,
-4.81267542e-01, 2.48430237e-01, 7.90232182e-01],
[ 4.94117826e-01, -8.06771696e-01, 7.47416198e-01, ...,
-7.30220854e-01, 1.18371248e+00, -5.31985700e-01],
[ 7.23220631e-02, -2.15284348e-01, -1.92686126e-01, ...,
-8.46448898e-01, -5.97360777e-03, -4.28196698e-01]],
[[-2.25013599e-01, -3.48351777e-01, -5.54229319e-01, ...,
-1.17885478e-01, 6.86471090e-02, -7.48188347e-02],
[-6.46478087e-02, -7.97209620e-01, 8.16888452e-01, ...,
1.15718633e-01, 1.01453686e+00, -2.39759922e-01],
[-5.25327086e-01, 5.88182211e-01, -9.02729273e-01, ...,
7.62883961e-01, -1.38593817e+00, -4.31858391e-01],
...,
[ 7.87195325e-01, -2.07848758e-01, -1.93571016e-01, ...,
7.28979111e-01, 5.80614686e-01, 3.75502616e-01],
[-2.03104734e-01, 7.44872868e-01, 1.03454679e-01, ...,
-1.42501622e-01, 2.46956795e-01, 1.22676164e-01],
[ 7.62553990e-01, -5.06239831e-01, 1.93014458e-01, ...,
5.26253104e-01, 2.82648057e-01, 8.88811767e-01]],
[[-1.04494788e-01, -6.20330453e-01, -2.88904518e-01, ...,
-9.90672529e-01, 2.64258742e-01, 2.14836478e-01],
[ 3.59708428e-01, -9.44501907e-02, 9.55824256e-01, ...,
-9.02668655e-01, 1.85327336e-01, 1.01691764e-03],
[-3.08903068e-01, 1.56039998e-01, 6.49919152e-01, ...,
3.52934897e-01, -4.29365396e-01, -1.05594814e+00],
...,
[ 5.94020307e-01, 3.79775882e-01, 8.62827957e-01, ...,
-8.00387859e-01, -5.81361890e-01, 8.48889947e-02],
[-2.51786768e-01, -3.84436883e-02, 3.43284369e-01, ...,
1.34449214e-01, 1.51265606e-01, -3.25558543e-01],
[-5.40423274e-01, 8.11421871e-01, -1.04027975e+00, ...,
1.75916687e-01, -3.55451852e-01, -4.18996513e-01]]],
[[[ 4.03016716e-01, -9.42383185e-02, 3.89042109e-01, ...,
-6.66733503e-01, 7.01597154e-01, 9.04736102e-01],
[-8.01422894e-01, 9.61738050e-01, -1.64667815e-01, ...,
-5.52261829e-01, -6.96025670e-01, -1.16403568e+00],
[-3.79517712e-02, 1.41728476e-01, -3.10020149e-01, ...,
1.44183421e+00, 5.36346138e-02, -1.90599233e-01],
...,
[-5.14048576e-01, -2.21568689e-01, 5.31760991e-01, ...,
-1.03457296e+00, 2.44072124e-01, -5.40369947e-04],
[-5.53254724e-01, 1.03404298e-01, -8.44387431e-03, ...,
5.00614405e-01, 2.78823525e-01, 2.91976810e-01],
[-6.95665359e-01, 5.69382131e-01, -1.19335508e+00, ...,
-2.95902044e-01, 3.53888348e-02, 5.01627335e-03]],
[[ 1.20825207e+00, -3.29682589e-01, -3.54656160e-01, ...,
-8.61604214e-02, -2.49787807e-01, 5.63491404e-01],
[ 6.02723479e-01, -8.93569291e-02, 5.49378812e-01, ...,
5.55969119e-01, -4.93603408e-01, 1.42602459e-01],
[ 4.12537068e-01, -8.80664825e-01, 6.61745310e-01, ...,
-7.31622279e-01, -7.21969008e-01, -1.03079927e+00],
...,
[-6.26841545e-01, -3.40548962e-01, 3.20449062e-02, ...,
-7.87570834e-01, 9.43054736e-01, -1.60972886e-02],
[ 8.26191306e-01, -5.19545734e-01, 5.71801364e-01, ...,
2.79021472e-01, -3.34796578e-01, -8.44013095e-01],
[ 3.60502213e-01, 4.38319683e-01, 2.13991180e-01, ...,
-6.36544347e-01, -5.06267473e-02, -5.51769555e-01]],
[[-8.10957193e-01, 1.13378942e+00, -1.47388566e+00, ...,
2.66789883e-01, 7.00945854e-02, 4.77081358e-01],
[ 2.62345165e-01, -4.42802548e-01, 1.32407904e-01, ...,
-1.38010776e+00, -1.99280515e-01, 8.50553453e-01],
[ 6.35858774e-01, 1.65538931e+00, -2.15716064e-01, ...,
-5.80295384e-01, -5.62521100e-01, 7.77037382e-01],
...,
[ 5.82785085e-02, -4.98133540e-01, -3.92823726e-01, ...,
5.99813223e-01, 5.84782600e-01, 1.82665259e-01],
[ 1.19108522e+00, 4.52362090e-01, -2.45918483e-02, ...,
-2.15084866e-01, -2.15071514e-01, 1.52354971e-01],
[ 5.26105404e-01, 2.52852559e-01, 3.06775961e-02, ...,
-1.40375543e+00, 1.26784432e+00, 4.97334570e-01]],
...,
[[ 7.99124658e-01, 6.19785070e-01, -2.05564678e-01, ...,
8.10738921e-01, -1.40358889e+00, 1.31820530e-01],
[-4.84217465e-01, -4.65817660e-01, -1.86626345e-01, ...,
-2.83500314e-01, -1.27768144e-01, -8.91454667e-02],
[ 3.24485391e-01, -8.31413984e-01, 2.12071046e-01, ...,
-3.96526724e-01, -7.63818026e-01, -4.76207286e-01],
...,
[-2.97311485e-01, -1.00890353e-01, -4.08464611e-01, ...,
-3.07795316e-01, 3.18488568e-01, 3.35744619e-01],
[-4.40074831e-01, 4.96284395e-01, -5.66527307e-01, ...,
1.37263402e-01, -6.25700802e-02, -2.86034942e-01],
[-2.77784169e-01, -2.70140618e-01, 9.17928815e-02, ...,
1.35965988e-01, 2.35500276e-01, -2.41373680e-04]],
[[-3.21114600e-01, 3.98868248e-02, 3.31562501e-03, ...,
-1.06464684e+00, -1.07360594e-02, 1.61534380e-02],
[-5.94698310e-01, -7.75814295e-01, -8.55177194e-02, ...,
6.45053446e-01, -5.99738657e-02, 3.18236798e-01],
[ 5.18429339e-01, 6.83645844e-01, 5.31058192e-01, ...,
-6.40040517e-01, -2.74345487e-01, 6.71724200e-01],
...,
[-2.59246260e-01, -9.52060878e-01, 1.99519381e-01, ...,
-7.06413329e-01, 5.07094800e-01, 1.99018568e-01],
[-6.45661831e-01, -2.64860898e-01, 4.11892146e-01, ...,
-2.21386373e-01, -1.03395559e-01, -4.46663290e-01],
[ 7.03365728e-02, -1.71302333e-01, 6.37728274e-01, ...,
5.99841118e-01, -3.00094217e-01, -1.68467209e-01]],
[[-1.63223729e-01, 5.45847535e-01, 8.18708718e-01, ...,
-7.14481771e-01, 1.95999846e-01, -2.45239045e-02],
[-8.57500196e-01, -1.14799410e-01, -3.29935670e-01, ...,
1.40195891e-01, -1.25240743e+00, -1.42649800e-01],
[ 3.91972065e-01, 2.59413898e-01, 3.64314884e-01, ...,
8.55417967e-01, 1.25591683e+00, 5.47372282e-01],
...,
[ 9.93716195e-02, 7.77123153e-01, 8.93170595e-01, ...,
3.03051144e-01, 7.42947683e-02, 7.33308613e-01],
[ 3.11880819e-02, -2.95177728e-01, 2.06190944e-01, ...,
1.53311670e+00, 4.61079597e-01, 3.99204530e-02],
[-5.59754610e-01, 2.39114627e-01, -1.36300159e+00, ...,
-4.93013799e-01, -7.12665096e-02, 7.62722492e-02]]],
[[[-7.07550526e-01, -3.03240091e-01, 6.60755455e-01, ...,
-7.90750161e-02, -3.65321159e-01, -4.30483460e-01],
[-1.86911270e-01, 2.39894748e-01, -9.40757334e-01, ...,
3.57000470e-01, -2.11175028e-02, -1.62767589e-01],
[ 5.51772773e-01, 1.32477924e-01, 4.75931138e-01, ...,
6.03891313e-01, 3.50777097e-02, -2.31402159e-01],
...,
[ 3.60434175e-01, -2.40440086e-01, 1.36279285e-01, ...,
8.35516527e-02, 8.36447299e-01, -5.07031977e-02],
[-2.87243992e-01, 5.44507802e-01, -7.85944313e-02, ...,
-3.60828161e-01, 5.22470474e-02, 2.05641940e-01],
[-6.79957271e-01, -2.88661897e-01, -6.68929160e-01, ...,
-8.08060616e-02, 3.17035764e-01, 3.03957164e-02]],
[[-1.76532626e-01, 5.74718297e-01, -6.78593397e-01, ...,
7.33274281e-01, -1.62762314e-01, -3.74133945e-01],
[ 6.47938073e-01, -8.46699327e-02, -2.87556887e-01, ...,
5.72368860e-01, 2.43657812e-01, 6.66255951e-01],
[ 4.25566167e-01, -6.37027442e-01, 2.88296103e-01, ...,
-5.30690730e-01, 5.80780327e-01, 2.29448169e-01],
...,
[ 1.95712298e-01, 3.00102949e-01, -6.41289175e-01, ...,
7.47099340e-01, -5.37887156e-01, -3.03677708e-01],
[ 5.38146235e-02, -4.94787879e-02, 1.87186092e-01, ...,
6.59565553e-02, -3.63303535e-02, -2.02135652e-01],
[ 1.06148578e-01, 2.69699812e-01, -1.33452356e-01, ...,
7.64168203e-01, -9.29973841e-01, 6.16707318e-02]],
[[ 5.97640455e-01, 3.70381474e-01, -1.05833687e-01, ...,
1.04272676e+00, -4.54665720e-01, -7.94478431e-02],
[ 2.62423337e-01, 2.42638126e-01, 3.63827407e-01, ...,
-1.93119079e-01, 6.85722753e-02, 3.13739687e-01],
[-2.03977063e-01, -8.53810668e-01, 1.41490296e-01, ...,
-7.12618008e-02, 2.41974235e-01, 1.75281800e-02],
...,
[-3.56253654e-01, -2.64765304e-02, -1.70551449e-01, ...,
-1.18611500e-01, -4.53533858e-01, -1.33774772e-01],
[ 2.49976337e-01, 6.65123165e-02, 3.74623120e-01, ...,
1.98802855e-02, 2.49585107e-01, 4.90735173e-01],
[ 1.00049064e-01, 1.08251944e-01, -4.02322143e-01, ...,
5.90216041e-01, -9.96080697e-01, 2.94348657e-01]],
...,
[[-2.33743116e-01, 3.79105168e-03, 1.00059927e+00, ...,
-3.50684762e-01, -2.03996941e-01, 5.43174446e-01],
[-1.31496513e+00, -7.68835992e-02, 9.21740979e-02, ...,
1.82122439e-01, 3.92907321e-01, 4.70603764e-01],
[-1.91511422e-01, -3.12377542e-01, -2.02807784e-01, ...,
4.74840961e-02, 3.90858687e-02, 5.20986259e-01],
...,
[ 3.92100602e-01, -2.56313950e-01, 4.40316498e-01, ...,
-7.95090020e-01, 9.58722889e-01, 6.68692514e-02],
[-2.37561002e-01, -2.45921806e-01, -4.79896784e-01, ...,
3.05862486e-01, 4.18518394e-01, -7.09717393e-01],
[ 9.63735729e-02, 5.26934683e-01, -2.43123412e-01, ...,
-2.29806781e-01, 8.54906738e-01, 1.45679355e-01]],
[[-6.50098324e-01, 1.00708961e+00, 4.81430953e-03, ...,
-4.03477103e-01, 5.49487919e-02, -5.28567910e-01],
[-7.36448050e-01, 2.37548098e-01, 6.89756870e-01, ...,
1.44244343e-01, 4.32573399e-03, -1.50717488e-02],
[-3.26949120e-01, 1.15860653e+00, -4.18980062e-01, ...,
6.63382530e-01, 2.57124096e-01, 5.86564839e-01],
...,
[ 5.46750352e-02, 3.76966506e-01, 1.79619983e-01, ...,
7.27700114e-01, 7.69920886e-01, 3.30334604e-01],
[-3.29259604e-01, -7.28786111e-01, 4.86835659e-01, ...,
3.71421605e-01, 6.50456697e-02, -1.75632238e-01],
[-2.41953239e-01, -5.99552631e-01, 2.57653624e-01, ...,
3.01961869e-01, 9.32633221e-01, -5.71823642e-02]],
[[ 1.64552823e-01, -1.26616490e+00, 1.83328986e-01, ...,
7.92825967e-02, 9.05444622e-01, 2.93801337e-01],
[ 1.69471517e-01, 1.89377040e-01, -3.94473910e-01, ...,
6.73406303e-01, -2.55624741e-01, -5.90663373e-01],
[-2.28453115e-01, 1.53474778e-01, 7.06688702e-01, ...,
1.02944827e+00, -9.44124758e-02, -4.60808694e-01],
...,
[ 2.38406166e-01, -1.16364264e+00, 9.53676283e-01, ...,
5.98508120e-01, 8.32367301e-01, 2.21683785e-01],
[-2.71952063e-01, -6.16223276e-01, 4.03875411e-01, ...,
2.75555700e-02, 1.19690061e-01, 1.28521711e-01],
[ 1.00875773e-01, -5.86154759e-01, 5.06292939e-01, ...,
1.41296101e+00, 1.24331325e-01, -2.65127104e-02]]]],
dtype=float32)>
Dim=5 Tensor
- Single task: [b, h, w, 3]
- meta-learning:
- [task_b, b, h, w, 3]