NEP 50 — Python 标量的晋升规则#
- 作者:
Sebastian Berg
- 状态:
最终
- 类型:
标准追踪
- 创建日期:
2021-05-25
摘要#
自 NumPy 1.7 以来,晋升规则采用了所谓的“安全类型转换”(safe casting),这依赖于对所涉及值的检查。这有助于为用户识别一些边界情况,但实现起来很复杂,也使得行为难以预测。
存在两种令人困惑的结果:
基于值的晋升意味着,例如 Python 整数的值可以决定由
np.result_type
确定的输出类型。np.result_type(np.int8, 1) == np.int8 np.result_type(np.int8, 255) == np.int16
这种逻辑出现的原因是
1
可以由uint8
或int8
表示,而255
不能由int8
表示,只能由uint8
或int16
表示。这在使用 0-D 数组(即“标量数组”)时也成立。
int64_0d_array = np.array(1, dtype=np.int64) np.result_type(np.int8, int64_0d_array) == np.int8
其中,
int64_0d_array
具有int64
dtype 的事实对结果 dtype 没有影响。在此示例中,dtype=np.int64
实际上被忽略了,因为只有它的值才重要。对于 Python 的
int
、float
或complex
,会像前面所示那样检查其值。但令人惊讶的是,当 NumPy 对象是 0-D 数组或 NumPy 标量时则 不会 检查。np.result_type(np.array(1, dtype=np.uint8), 1) == np.int64 np.result_type(np.int8(1), 1) == np.int64
原因是当所有对象都是标量或 0-D 数组时,基于值的晋升会被禁用。因此 NumPy 返回与
np.array(1)
相同的类型,这通常是int64
(这取决于系统)。
请注意,这些示例也适用于乘法、加法、比较等操作,以及它们对应的函数,如 np.multiply
。
本 NEP 建议根据以下两个指导原则重构行为:
值绝不能影响结果类型。
NumPy 标量和 0-D 数组的行为应与其 N-D 对应物保持一致。
我们建议移除所有基于值的逻辑,并为 Python 标量添加特殊处理,以保留一些便利行为。Python 标量将被视为“弱”类型。当 NumPy 数组/标量与 Python 标量结合时,它将被转换为 NumPy 的 dtype,这样:
np.array([1, 2, 3], dtype=np.uint8) + 1 # returns a uint8 array
np.array([1, 2, 3], dtype=np.float32) + 2. # returns a float32 array
将不再依赖 Python 值本身。
所提出的更改也适用于 np.can_cast(100, np.int8)
,然而,我们预计函数中的行为(晋升)在实践中将远比类型转换本身的变化更重要。
注意
截至 NumPy 1.24.x 系列,NumPy 已提供初步和有限的支持来测试此提案。
此外,还需要设置以下环境变量:
export NPY_PROMOTION_STATE=weak
有效值为 weak
、weak_and_warn
和 legacy
。请注意,weak_and_warn
实现了本 NEP 中提议的可选警告,预计会 非常 嘈杂。我们建议首先使用 weak
选项,并主要使用 weak_and_warn
来理解观察到的特定行为变化。
存在以下额外的 API:
np._set_promotion_state()
和np._get_promotion_state()
,它们等同于环境变量。(非线程/上下文安全。)with np._no_nep50_warning():
允许在使用weak_and_warn
晋升时抑制警告。(线程和上下文安全。)
目前,整数幂运算的溢出警告缺失。此外,np.can_cast
在 weak_and_warn
模式下无法发出警告。其对 Python 标量输入的行为可能仍在变化中(这应该只会影响极少数用户)。
新提案的晋升规则示意图#
更改后,NumPy 中的晋升将遵循以下示意图。晋升总是沿着绿线发生:在同类中从左到右,只有在必要时才晋升到更高一类。结果的种类始终是输入中最大的种类。请注意,float32
的精度低于 int32
或 uint32
,因此在示意图中稍微靠左排列。这是因为 float32
无法精确表示所有 int32
值。然而,出于实际原因,NumPy 允许将 int64
晋升为 float64
,实际上认为它们具有相同的精度。
Python 标量被插入到每种“类型”的最左侧,Python 整数不区分有符号和无符号。因此 NumPy 晋升使用以下有序的类型类别:
布尔
整型:有符号或无符号整数
不精确型:浮点数和复数浮点数
当将较低类型类别(布尔 < 整型 < 不精确型)的 Python 标量与较高类型类别进行晋升时,我们使用最小/默认精度:即 float64
、complex128
或 int64
(某些系统,例如 Windows,使用 int32
)。
请参阅下一节,其中包含阐明所提议行为的示例。与当前行为进行比较的更多示例可以在下表中找到。
新行为示例#
为了更容易理解上述文本和图示,我们提供了一些新行为的示例。在下面的示例中,Python 整数对结果类型没有影响。
np.uint8(1) + 1 == np.uint8(2)
np.int16(2) + 2 == np.int16(4)
在以下示例中,Python 的 float
和 complex
是“不精确的”,但 NumPy 值是整型,因此我们至少使用 float64
/complex128
。
np.uint16(3) + 3.0 == np.float64(6.0)
np.int16(4) + 4j == np.complex128(4+4j)
但这不会发生在 float 到 complex
的晋升中,因为 float32
和 complex64
具有相同的精度。
np.float32(5) + 5j == np.complex64(5+5j)
请注意,示意图省略了 bool
。它被置于“整型”之下,因此以下情况成立:
np.bool_(True) + 1 == np.int64(2)
True + np.uint8(2) == np.uint8(3)
请注意,尽管本 NEP 使用简单运算符作为示例,但所描述的规则通常适用于所有 NumPy 操作。
新旧行为对比表#
下表列出了相关更改和未更改的行为。有关导致“旧结果”的规则的详细说明,请参阅旧实现,有关新规则的详细说明,请参阅以下章节。向后兼容性部分讨论了这些更改可能如何影响用户。
请注意 0-D 数组(如 array(2)
)与非 0-D 数组(如 array([2])
)之间的重要区别。
表达式 |
旧结果 |
新结果 |
---|---|---|
|
|
|
|
|
|
|
|
|
|
|
未改变 |
|
|
未改变 |
|
|
未改变 [T3] |
|
|
异常 [T4] |
|
|
异常 [T5] |
|
|
|
|
|
|
|
|
未改变 |
|
|
|
|
|
未改变 |
|
|
|
|
|
|
|
|
|
|
|
未改变 [T12] |
新行为尊重 uint8
标量的 dtype。
当前 NumPy 在 0-D 数组或 NumPy 标量与数组结合时,会忽略它们的精度。
当前 NumPy 在 0-D 数组或 NumPy 标量与数组结合时,会忽略它们的精度。
旧行为使用 uint16
,因为 300
不适合 uint8
;新行为出于同样的原因会引发错误。
300
无法转换为 uint8
。
这可能是最危险的更改之一。保留类型会导致溢出。对于 NumPy 标量,会发出指示溢出的 RuntimeWarning
。
np.float32(3e100)
溢出为无穷大并发出警告。
1 + 1e-14
在 float32 中执行时会失去精度,但在 float64 中则不会。旧行为是根据数组的维度将标量参数转换为不同的 float32 或 float64;新行为下,计算总是以数组的精度进行(在本例中为 float32)。
NumPy 将 float32
和 int64
晋升为 float64
。旧行为在这里忽略了 int64
。
新行为在 array(3, complex64)
和 array([3], complex64)
之间保持一致:结果的 dtype 与数组参数的 dtype 相同。
新行为使用与数组参数 float32
兼容的复数 dtype 精度。
由于数组类型是整型,结果使用默认的复数精度,即 complex128
。
动机与范围#
改变 Python 标量和 NumPy 标量/0-D 数组值检查行为的动机有三:
NumPy 标量/0-D 数组的特殊处理以及值检查可能会让用户感到非常意外,
值检查逻辑更难解释和实现。通过 NEP 42 将其提供给用户定义的 DType 也更困难。目前,这导致了新旧(对值敏感)系统的双重实现。修复此问题将极大地简化内部逻辑并使结果更加一致。
它在很大程度上与 JAX 和 data-apis.org 等其他项目的选择保持一致(另请参阅 相关工作)。
我们相信,“弱”Python 标量的提案将通过提供清晰的心智模型来帮助用户理解操作将产生何种数据类型。该模型与 NumPy 当前经常遵循的数组精度保留以及原地操作的使用非常契合。
arr += value
只要不越过“类型”边界,就保留精度(否则会引发错误)。
尽管一些用户可能会怀念值检查行为,但即使在它看起来有用的情况下,它也很快导致意外。这可能是预期的:
np.array([100], dtype=np.uint8) + 1000 == np.array([1100], dtype=np.uint16)
但以下情况会令人意外:
np.array([100], dtype=np.uint8) + 200 == np.array([44], dtype=np.uint8)
考虑到该提案与原地操作数的行为一致,并避免了有时才能避免结果溢出的令人惊讶的行为切换,我们相信该提案遵循了“最少意外原则”。
用法与影响#
本 NEP 预计将在没有过渡期的情况下实施,不会对所有更改发出警告。这样的过渡期会产生许多(通常无害的)警告,这些警告很难被抑制。我们预计大多数用户将长期受益于更清晰的晋升规则,并且很少有人会直接(负面地)受到此更改的影响。然而,某些使用模式可能会导致问题性更改,这些将在向后兼容性部分详细说明。
解决此问题的方法将是一个可选的警告模式,能够通知用户潜在的行为变化。此模式预计会产生许多无害的警告,但如果观察到问题,则提供一种系统地审查代码并跟踪变化的方法。
对 can_cast
的影响#
can_cast 将不再检查值。因此,以下结果预计将从 True
变为 False
:
np.can_cast(np.int64(100), np.uint8)
np.can_cast(np.array(100, dtype=np.int64), np.uint8)
np.can_cast(100, np.uint8)
我们预计此更改的影响与以下更改相比将很小。
注意
最后一个输入是 Python 标量的示例 _可能_ 会被保留,因为 100
可以由 uint8
表示。
对涉及 NumPy 数组或标量的运算符和函数的影响#
对不涉及 Python 标量(float
、int
、complex
)的操作的主要影响是,对 0-D 数组和 NumPy 标量的操作将永远不依赖于它们的值。这消除了当前令人惊讶的情况。例如:
np.arange(10, dtype=np.uint8) + np.int64(1)
# and:
np.add(np.arange(10, dtype=np.uint8), np.int64(1))
将来会返回一个 int64
数组,因为 np.int64(1)
的类型将被严格遵守。目前返回的是一个 uint8
数组。
对涉及 Python int
、float
和 complex
的运算符的影响#
本 NEP 试图在处理字面值时保留旧行为的便利性。当前基于值的逻辑在涉及“无类型”的字面 Python 标量时,具有一些不错的特性:
np.arange(10, dtype=np.int8) + 1 # returns an int8 array
np.array([1., 2.], dtype=np.float32) * 3.5 # returns a float32 array
但当遇到“不可表示”的值时,却导致了意外:
np.arange(10, dtype=np.int8) + 256 # returns int16
np.array([1., 2.], dtype=np.float32) * 1e200 # returns float64
该提案旨在大部分保留这种行为。这通过在操作中将 Python 的 int
、float
和 complex
视为“弱”类型来实现。然而,为了避免意外,我们计划使向新类型的转换更加严格:前两个示例的结果将保持不变,但在第二个示例中,它将按以下方式更改:
np.arange(10, dtype=np.int8) + 256 # raises a TypeError
np.array([1., 2.], dtype=np.float32) * 1e200 # warning and returns infinity
第二个例子会发出警告,因为 np.float32(1e200)
会溢出为无穷大。然后它将照常使用 inf
进行计算。
其他库中的行为
在转换过程中溢出而非引发错误是一种选择;这是大多数 C 环境中的默认行为(类似于 NumPy C 也可以设置为因溢出而引发错误)。例如,这也是 pytorch
1.10 的行为。
Python 整数的特殊行为#
NEP 的晋升规则是根据结果 dtype 来阐述的,结果 dtype 通常也是操作 dtype(就结果精度而言)。这导致了 Python 整数可能看似是例外的情况:尽管 uint8(3) + 1000
必须被拒绝,因为在 uint8
中操作是不可能的,但 uint8(3) / 1000
返回一个 float64
,并且可以将两个输入转换为 float64
来找到结果。
在实践中,这意味着在以下情况下接受任意 Python 整数值:
NumPy 整数与 Python 整数之间的所有比较(
==
、<
等)始终是明确定义的。像
np.sqrt
这样给出浮点结果的一元函数可以并将 Python 整数转换为浮点数。整数除法通过将输入转换为
float64
返回浮点数。
请注意,可能还有其他函数可以应用这些例外,但目前尚未应用。在这些情况下,允许它们应被视为一种改进,但当用户影响较低时,为了简单起见,我们可能不会这样做。
向后兼容性#
通常,只使用默认 dtype float64 或 int32/int64 或更精确 dtype 的代码不应受到影响。
然而,所提出的更改将在许多混合了 0-D 或标量值(具有非默认 dtype)的情况下修改结果。在许多情况下,这些将是错误修复,但是,某些更改可能会给最终用户带来问题。
最重要的可能失败大概是以下示例:
arr = np.arange(100, dtype=np.uint8) # storage array with low precision
value = arr[10]
# calculation continues with "value" without considering where it came from
value * 100
以前,value * 100
会导致向上类型转换为 int32
/int64
(因为 value 是一个标量)。新行为将保留较低的精度,除非明确处理(就像 value
是一个数组一样)。这可能导致整数溢出,从而导致超出精度的错误结果。在许多情况下,这可能是静默的,尽管 NumPy 通常会为标量运算符发出警告。
类似地,如果存储数组是 float32
,则计算可能会保留较低的 float32
精度,而不是使用默认的 float64
。
还可能出现其他问题。例如:
混合精度时,浮点数比较,特别是相等性比较,可能会发生变化。
np.float32(1/3) == 1/3 # was False, will be True.
某些操作预计将开始失败
np.array([1], np.uint8) * 1000 np.array([1], np.uint8) == 1000 # possibly also
以保护用户,避免先前基于值的类型转换导致向上类型转换的情况。(当将
1000
转换为uint8
时发生失败。)在更奇怪的情况下可能会发生浮点溢出
np.float32(1e-30) * 1e50 # will return ``inf`` and a warning
因为
np.float32(1e50)
返回inf
。以前,即使1e50
不是 0-D 数组,这也会返回双精度结果。
在其他情况下,可能会出现精度增加。例如:
np.multiple(float32_arr, 2.)
float32_arr * np.float64(2.)
两者都将返回 float64 而不是 float32
。这提高了精度,但会稍微改变结果并使用双倍内存。
整数“精度阶梯”导致的更改#
当从 Python 整数创建数组时,NumPy 将按以下顺序尝试类型,结果取决于其值:
long (usually int64) → int64 → uint64 -> object
这与上述的晋升规则略有不同。
本 NEP 目前不包括更改此“阶梯”(尽管可能会在单独的文档中提出)。然而,在混合操作中,此“阶梯”将被忽略,因为值将被忽略。这意味着,操作将永远不会默默地使用 object
dtype。
np.array([3]) + 2**100 # Will error
用户将不得不编写以下代码之一:
np.array([3]) + np.array(2**100)
np.array([3]) + np.array(2**100, dtype=object)
因此,隐式转换为 object
的情况应该很少见,并且解决方法很明确,我们预计向后兼容性问题相当小。
详细描述#
以下提供了一些关于当前“基于值”的晋升逻辑,以及“弱标量”晋升及其内部处理方式的额外细节。
“基于值”晋升的旧实现#
本节回顾了当前基于值的逻辑在实践中是如何工作的,有关其有用性的示例,请参阅下一节。
当 NumPy 遇到“标量”值时(可以是 Python int、float、complex、NumPy 标量或数组):
1000 # Python scalar
int32(1000) # NumPy scalar
np.array(1000, dtype=int64) # zero dimensional
或者浮点/复数等效项,NumPy 将忽略 dtype 的精度,并找到能够容纳该值的最小可能 dtype。也就是说,它将尝试以下 dtype:
整型:
uint8
、int8
、uint16
、int16
、uint32
、int32
、uint64
、int64
。浮点型:
float16
、float32
、float64
、longdouble
。复数型:
complex64
、complex128
、clongdouble
。
请注意,例如对于整数值 10
,最小的 dtype 可以是 或 uint8
或 int8
。
当所有参数都是标量值时,NumPy 从未应用此规则。
np.int64(1) + np.int32(2) == np.int64(3)
对于整数,值是否适合精确取决于它是否可以由 dtype 表示。对于浮点数和复数,如果满足以下条件,则 dtype 被认为是足够的:
float16
:-65000 < 值 < 65000
(或 NaN/Inf)float32
:-3.4e38 < 值 < 3.4e38
(或 NaN/Inf)float64
:-1.7e308 < 值 < 1.7e308
(或 NaN/Inf)longdouble
: (最大范围,因此没有限制)
对于复数,这些界限应用于实部和虚部。这些值大致对应于 np.finfo(np.float32).max
。(然而,NumPy 从未强制对 float32(3.402e38)
的值使用 float64
,但对于 Python 值 3.402e38
则会。)
当前“基于值”晋升的状态#
在我们提出当前数据类型系统的替代方案之前,回顾“基于值的晋升”如何使用以及如何有用是很有帮助的。基于值的晋升允许以下代码正常工作:
# Create uint8 array, as this is sufficient:
uint8_arr = np.array([1, 2, 3], dtype=np.uint8)
result = uint8_arr + 4
result.dtype == np.uint8
result = uint8_arr * (-1)
result.dtype == np.int16 # upcast as little as possible.
其中特别是第一部分可能很有用:用户知道输入是一个具有特定精度的整数数组。考虑到简单的 + 4
能够保留之前的数据类型是直观的。将此示例替换为 np.float32
可能会更清晰,因为浮点数很少会溢出。如果没有这种行为,上述示例将需要编写 np.uint8(4)
,并且缺乏这种行为会使以下情况令人惊讶:
result = np.array([1, 2, 3], dtype=np.float32) * 2.
result.dtype == np.float32
在这种情况下,由于缺乏特殊处理,将返回 float64
。
值得注意的是,这种行为也适用于通用函数和零维数组:
# This logic is also used for ufuncs:
np.add(uint8_arr, 4).dtype == np.uint8
# And even if the other array is explicitly typed:
np.add(uint8_arr, np.array(4, dtype=np.int64)).dtype == np.uint8
回顾一下,如果我们将 4
替换为 [4]
使其成为一维数组,结果将不同:
# This logic is also used for ufuncs:
np.add(uint8_arr, [4]).dtype == np.int64 # platform dependent
# And even if the other array is explicitly typed:
np.add(uint8_arr, np.array([4], dtype=np.int64)).dtype == np.int64
提出的弱晋升规则#
本提案使用“弱标量”逻辑。这意味着 Python 的 int
、float
和 complex
不会被分配到典型的 dtype(如 float64 或 int64)。相反,它们被分配一个特殊的抽象 DType,类似于“标量”层次结构名称:Integral、Floating、ComplexFloating。
当发生晋升时(就像 ufunc 在没有精确循环匹配时那样),另一个 DType 能够决定如何处理 Python 标量。例如,UInt16
与 Integral
晋升时,结果将是 UInt16
。
注意
未来很可能会为用户定义的 DType 提供一个默认值。这很可能最终会是默认的整数/浮点数,但原则上可以实现更复杂的方案。
在任何时候,都不会使用值来决定此晋升的结果。仅当值转换为新的 dtype 时才考虑该值;这可能会引发错误。
实现#
实现本 NEP 需要向所有二元运算符(或 ufuncs)添加一些额外的机制,以便它们在可能的情况下尝试使用“弱”逻辑。有两种可能的方法:
二元运算符在此情况出现时,简单地尝试调用
np.result_type()
并将 Python 标量转换为结果类型(如果已定义)。二元运算符指示输入是 Python 标量,其余部分由 ufunc 调度/晋升机制处理(参见 NEP 42)。这提供了更大的灵活性,但需要在 ufunc 机制中添加一些额外的逻辑。
注意
截至目前,尚不清楚哪种方法更好,两者都将产生相当等效的结果,并且如果需要,将来方法 1 可以通过方法 2 进行扩展。
它还要求移除所有当前基于值的特殊代码路径。
不直观的是,实现中更大的一步可能是实现一个解决方案,允许在以下示例中引发错误:
np.arange(10, dtype=np.uint8) + 1000
即使 np.uint8(1000)
返回与 np.uint8(232)
相同的值。
注意
请参阅替代方案,我们可能最终决定这种静默溢出是可接受的,或者至少是另一个独立的问题。
替代方案#
有几个设计维度可以进行不同的选择。以下章节概述了这些。
使用强类型标量或两者混合#
解决基于值的晋升/类型转换问题的最简单方法是使用强类型 Python 标量,即 Python 浮点数被视为双精度,Python 整数始终被视为与默认整数 dtype 相同。
这将是最简单的解决方案,但是,它在使用 float32
或 int16
等数组时会导致许多向上类型转换。这些情况的解决方案将是依赖于原地操作。我们目前认为,尽管危险性较低,但此更改会影响许多用户,并且常常会令人惊讶(尽管期望值差异很大)。
原则上,弱行为与强行为无需统一。也可以让 Python 浮点数使用弱行为,而 Python 整数使用强行为,因为整数溢出远比浮点溢出更令人惊讶。
不在函数中使用弱标量逻辑#
本 NEP 提案的替代方案之一是将弱类型的使用范围缩小到 Python 运算符。
这有利有弊:
主要优点是,将其限制在 Python 运算符意味着这些“弱”类型/dtype 显然是短暂的,仅限于简短的 Python 语句。
缺点是
np.multiply
和*
的互换性较差。仅将“弱”晋升用于运算符意味着库无需担心是否要“记住”某个输入最初是 Python 标量。另一方面,它会增加对 Python 运算符稍有不同(或额外)逻辑的需求。(技术上,可能作为 ufunc 调度机制的一个标志来切换弱逻辑。)
__array_ufunc__
通常单独使用,为实现它的类数组(array-likes)提供 Python 运算符支持。如果运算符是特殊的,这些类数组可能需要一种机制来匹配 NumPy(例如,为 ufuncs 提供一个关键字参数以启用弱晋升)。
NumPy 标量可以特殊处理#
许多用户期望 NumPy 标量与 NumPy 数组不同,即 np.uint8(3) + 3
应该返回一个 int64
(或 Python 整数),而 uint8_arr + 3
则保留 uint8
dtype。
这种替代方案将非常接近 NumPy 标量的当前行为,但它会固化数组和标量之间的区别(NumPy 数组比 Python 标量“强”,但 NumPy 标量则不然)。
这种区分是完全可能的,然而,目前 NumPy 经常(并且静默地)将 0-D 数组转换为标量。因此,只有当我们也改变这种静默转换(有时称为“衰减”)行为时,考虑这种替代方案才有意义。
处理不安全的标量转换#
例如:
np.arange(10, dtype=np.uint8) + 1000
根据本 NEP,应引发错误。这可以放宽为发出警告,甚至忽略“不安全”的转换,这(在所有相关硬件上)将导致使用 np.uint8(1000) == np.uint8(232)
。
允许弱类型数组#
Python 标量是弱类型但数组不是弱类型的一个问题是,在许多情况下,np.asarray()
会不加区分地作用于输入。为了解决这个问题,JAX 会将 np.asarray(1)
的结果也视为弱类型。然而,这有两个难点:
JAX 注意到这可能会令人困惑:
np.broadcast_to(np.asarray(1), (100, 100))
是一个非 0-D 数组,它“继承”了弱类型。[2]
与 JAX 张量不同,NumPy 数组是可变的,因此赋值可能需要使其变为强类型?
一个标志很可能在实现细节中很有用(例如在 ufuncs 中),然而,目前我们不期望将其作为用户 API。主要原因是,如果这种标志作为函数结果传出,而不是仅在非常局部化的范围内使用,可能会让用户感到意外。
待办事项
在接受 NEP 之前,最好进一步讨论这个问题。库可能需要更清晰的模式来“传播”这种“弱”类型,这可能只是使用 np.asarray_or_literal()
来保留 Python 标量,或者是在调用 np.asarray()
之前调用 np.result_type()
的模式。
继续为 Python 标量使用基于值的逻辑#
当前逻辑的一些主要问题出现,是因为我们将其应用于 NumPy 标量和 0-D 数组,而不是应用于 Python 标量。因此我们可以考虑继续检查 Python 标量的值。
我们拒绝这个想法,理由是它不会消除前面提到的意外:
np.uint8(100) + 1000 == np.uint16(1100)
np.uint8(100) + 200 == np.uint8(44)
并且基于结果值而非输入值调整精度对于标量操作可能可行,但对于数组操作则不可行。这是因为数组操作需要在执行计算之前分配结果数组。
讨论#
参考文献和脚注#
版权#
本文档已置于公共领域。[1]