NumPy squeeze()

squeeze() 方法用于移除数组中尺寸大小为 1 的维度。

示例

import numpy as np

# create a 3-D array
array1 = np.array([[[0, 1]]])

# squeeze the array squeezedArray = np.squeeze(array1)
print(squeezedArray) # Output : [0 1]

在此,array1 是一个 3 维数组,具有两个单例维度(尺寸大小为1的维度)。因此,这两个单例维度被移除,array1 从三维被压缩到一维。


squeeze() 语法

squeeze() 的语法是

numpy.squeeze(array, axis = None)

squeeze() 参数

squeeze() 方法接受两个参数

  • array - 要压缩的数组
  • axis(可选) - 沿其压缩数组的轴(Noneinttuple

squeeze() 返回值

squeeze() 方法返回压缩后的数组。


示例 1:压缩具有单维条目的数组

import numpy as np

array1 = np.array([[[1, 2, 3]]])

# squeeze the array squeezedArray = np.squeeze(array1)
print(squeezedArray)

输出

[1 2 3]

示例 2:压缩具有多个单维条目的数组

import numpy as np

array1 = np.array([[1], [2], [3]]) 

# squeeze the array squeezedArray = np.squeeze(array1)
print(squeezedArray)

输出

[1 2 3]

示例 3:沿特定轴进行压缩

如果不传递 axis 参数,则默认为 None,所有长度为 1 的维度都会被移除。

但是,我们可以指定要压缩的特定轴。

import numpy as np
array1 = np.array([[[1], [2], [3]]])

print('Original Array: \n', array1, "\nShape: ",array1.shape, '\n')

# squeeze array1
array2 = np.squeeze(array1)  

print('Squeezed Array: \n', array2, "\nShape: ",array2.shape, '\n')

# squeeze array1 along axis 0 array3 = np.squeeze(array1, axis = 0)
print('Squeezed Array along axis 0: \n', array3, "\nShape: ",array3.shape, '\n')
# squeeze array1 along the last axis array4 = np.squeeze(array1, axis = -1)
print('Squeezed Array along last axis: \n', array4, "\nShape: ",array4.shape, '\n')
# squeeze array1 along axis 9 and 2 array5 = np.squeeze(array1, axis = (0, 2))
print('Squeezed Array along axis (0, 2): \n', array5, "\nShape: ",array5.shape, '\n')

输出

Original Array: 
[[[1]
  [2]
  [3]]] 
Shape:  (1, 3, 1) 

Squeezed Array: 
 [1 2 3] 
Shape:  (3,) 

Squeezed Array along axis 0: 
 [[1]
 [2]
 [3]] 
Shape:  (3, 1) 

Squeezed Array along last axis: 
 [[1 2 3]] 
Shape:  (1, 3) 

Squeezed Array along axis (0, 2): 
 [1 2 3] 
Shape:  (3,) 

示例 4:所有维度长度都为 1 的情况下的压缩

如果所有维度长度都为 1,则返回一个标量值。

import numpy as np

array1 = np.array([[[123]]])

# squeeze array1
array2 = np.squeeze(array1)  

print('Squeezed Array: \n', array2)

输出

123 

注意: 尽管 123 是一个标量值,但它仍然被视为一个数组。例如,

print(type(array2)) #<class 'numpy.ndarray'>

我们的高级学习平台,凭借十多年的经验和数千条反馈创建。

以前所未有的方式学习和提高您的编程技能。

试用 Programiz PRO
  • 交互式课程
  • 证书
  • AI 帮助
  • 2000+ 挑战