屏蔽数组#

你会做什么#

使用 NumPy 中的屏蔽数组模块来分析 COVID-19 数据并处理缺失值。

你将学到什么#

  • 您将了解什么是掩码数组以及如何创建它们

  • 您将了解如何访问和修改屏蔽数组的数据

  • 您将能够决定何时在某些应用程序中适合使用掩码数组

你需要什么#

  • 对 Python 有基本的了解。如果您想加深记忆,请查看Python 教程

  • 基本熟悉 NumPy

  • 要在计算机上运行绘图,您需要matplotlib


什么是掩码数组?#

考虑以下问题。您的数据集缺少条目或条目无效。如果您正在对此数据进行任何类型的处理,并且想要跳过或标记这些不需要的条目而不只是删除它们,则可能必须使用条件或以某种方式过滤数据。numpy.ma模块提供了一些与NumPy ndarray相同的功能,并添加了结构以确保在计算中不使用无效条目。

来自参考指南

掩码数组是标准numpy.ndarraymask的组合。掩码可以是nomask,指示关联数组中没有无效值,也可以是布尔值数组,用于确定关联数组的每个元素的值是否有效。当掩码的元素为 时False,关联数组的相应元素有效,并且被称为未掩码。当掩码的元素为 时True,关联数组的相应元素被称为被掩码(无效)。

我们可以将MaskedArray视为以下各项的组合:

  • numpy.ndarray数据,作为任何形状或数据类型的常规;

  • 与数据形状相同的布尔掩码;

  • A fill_value,一个可用于替换无效条目以返回标准的值numpy.ndarray

它们什么时候有用?#

在某些情况下,屏蔽数组比仅仅消除数组的无效条目更有用:

  • 当您想要保留屏蔽的值以供以后处理,而不复制数组时;

  • 当您必须处理许多数组时,每个数组都有自己的掩码。如果掩码是数组的一部分,则可以避免错误并且代码可能会更紧凑;

  • 当您对缺失值或无效值有不同的标志,并且希望保留这些标志而不在原始数据集中替换它们,但将它们排除在计算之外时;

  • 如果您无法避免或消除缺失值,但又不想在运算中处理NaN(非数字)值。

屏蔽数组也是一个好主意,因为该模块还附带了大多数NumPy 通用函数 (ufunc)numpy.ma的特定实现,这意味着您仍然可以对屏蔽数据应用快速向量化函数和操作。然后输出是一个掩码数组。我们将在下面看到一些在实践中如何运作的示例。

使用屏蔽数组查看 COVID-19 数据#

可以从Kagglewho_covid_19_sit_rep_time_series.csv下载包含 2020 年初 COVID-19 爆发初始数据的数据集。我们将查看文件中包含的一小部分数据。(请注意,该文件已在 2020 年底的某个时间被替换为没有丢失数据的版本。)

import numpy as np
import os

# The os.getcwd() function returns the current folder; you can change
# the filepath variable to point to the folder where you saved the .csv file
filepath = os.getcwd()
filename = os.path.join(filepath, "who_covid_19_sit_rep_time_series.csv")

数据文件包含不同类型的数据,组织如下:

  • 第一行是标题行,(主要)描述下面各行中每列中的数据,从第四列开始,标题是观察的日期。

  • 第二行到第七行包含的摘要数据与我们要检查的数据类型不同,因此我们需要将其从我们将使用的数据中排除。

  • 我们希望处理的数值数据从第 4 列第 8 行开始,并从那里延伸到最右边的列和最下面的行。

让我们探索该文件中前 14 天记录的数据。为了从文件中收集数据.csv,我们将使用numpy.genfromtxt函数,确保我们仅选择具有实际数字的列,而不是包含位置数据的前四列。我们还跳过该文件的前 6 行,因为它们包含我们不感兴趣的其他数据。另外,我们将提取有关该数据的日期和位置的信息。

# Note we are using skip_header and usecols to read only portions of the
# data file into each variable.
# Read just the dates for columns 4-18 from the first row
dates = np.genfromtxt(
    filename,
    dtype=np.str_,
    delimiter=",",
    max_rows=1,
    usecols=range(4, 18),
    encoding="utf-8-sig",
)
# Read the names of the geographic locations from the first two
# columns, skipping the first six rows
locations = np.genfromtxt(
    filename,
    dtype=np.str_,
    delimiter=",",
    skip_header=6,
    usecols=(0, 1),
    encoding="utf-8-sig",
)
# Read the numeric data from just the first 14 days
nbcases = np.genfromtxt(
    filename,
    dtype=np.int_,
    delimiter=",",
    skip_header=6,
    usecols=range(4, 18),
    encoding="utf-8-sig",
)

numpy.genfromtxt函数调用中,我们为每个数据子集选择了numpy.dtype(整数 - numpy.int_- 或字符串 - numpy.str_)。我们还使用参数encoding来选择utf-8-sig文件的编码(在官方 Python 文档中阅读有关编码的更多信息。您可以从参考文档基本 IO 教程numpy.genfromtxt中阅读有关该函数的更多信息。

探索数据#

首先,我们可以绘制我们拥有的整套数据,看看它是什么样子的。为了获得可读的图,我们仅选择几个日期来显示在x 轴刻度中。另请注意,在绘图命令中,我们使用nbcases.T(数组的转置nbcases),因为这意味着我们将文件的每一行绘制为单独的行。我们选择绘制虚线(使用'--'线条样式)。有关详细信息,请参阅matplotlib文档。

import matplotlib.pyplot as plt

selected_dates = [0, 3, 11, 13]
plt.plot(dates, nbcases.T, "--")
plt.xticks(selected_dates, dates[selected_dates])
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
Text(0.5, 1.0, 'COVID-19 cumulative cases from Jan 21 to Feb 3 2020')
../_images/20fe53540422fbc1ad5507ed954165f8e94e27e8e2b7f1459aa64880e40d2a1b.png

从 1 月 24 日到 2 月 1 日,该图的形状很奇怪。知道这些数据来自哪里会很有趣。如果我们查看locations从文件中提取的数组.csv,我们可以看到有两列,其中第一列包含区域,第二列包含国家/地区名称。但是,只有前几行包含第一列的数据(中国的省份名称)。之后,我们就只有国家名称了。因此,将来自中国的所有数据分组到一行中是有意义的。为此,我们将从数组中nbcases仅选择数组的第二个条目locations对应于中国的行。接下来,我们将使用numpy.sum函数对所有选定的行 ( axis=0) 求和。另请注意,第 35 行对应于每个日期的整个国家/地区的总计数。由于我们想根据省份数据自行计算总和,因此我们必须首先从 和 中删除该locationsnbcases

totals_row = 35
locations = np.delete(locations, (totals_row), axis=0)
nbcases = np.delete(nbcases, (totals_row), axis=0)

china_total = nbcases[locations[:, 1] == "China"].sum(axis=0)
china_total
array([  247,   288,   556,   817,   -22,   -22,   -15,   -10,    -9,
          -7,    -4, 11820, 14410, 17237])

这个数据有问题——我们不应该在累积数据集中有负值。这是怎么回事?

缺失数据

查看数据,我们发现:有一个时期缺少数据

nbcases
array([[  258,   270,   375, ...,  7153,  9074, 11177],
       [   14,    17,    26, ...,   520,   604,   683],
       [   -1,     1,     1, ...,   422,   493,   566],
       ...,
       [   -1,    -1,    -1, ...,    -1,    -1,    -1],
       [   -1,    -1,    -1, ...,    -1,    -1,    -1],
       [   -1,    -1,    -1, ...,    -1,    -1,    -1]])

我们看到的所有-1值都来自numpy.genfromtxt尝试从原始文件中读取丢失的数据.csv。显然,我们不想计算丢失的数据-1,因为我们只是想跳过这个值,这样它就不会干扰我们的分析。导入numpy.ma模块后,我们将创建一个新数组,这次屏蔽无效值:

from numpy import ma

nbcases_ma = ma.masked_values(nbcases, -1)

如果我们查看nbcases_ma掩码数组,这就是我们所拥有的:

nbcases_ma
masked_array(
  data=[[258, 270, 375, ..., 7153, 9074, 11177],
        [14, 17, 26, ..., 520, 604, 683],
        [--, 1, 1, ..., 422, 493, 566],
        ...,
        [--, --, --, ..., --, --, --],
        [--, --, --, ..., --, --, --],
        [--, --, --, ..., --, --, --]],
  mask=[[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [ True, False, False, ..., False, False, False],
        ...,
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],
  fill_value=-1)

我们可以看到这是一种不同类型的数组。正如简介中提到的,它具有三个属性(datamaskfill_value)。请记住,该mask属性具有与无效True数据对应的元素值(在属性中由两个破折号表示)。data

让我们尝试看看排除第一行(来自中国湖北省的数据)后的数据是什么样子,以便我们可以更仔细地查看缺失的数据:

plt.plot(dates, nbcases_ma[1:].T, "--")
plt.xticks(selected_dates, dates[selected_dates])
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
Text(0.5, 1.0, 'COVID-19 cumulative cases from Jan 21 to Feb 3 2020')
../_images/80ec8b4f7c2bbc7c0911404bda4e039db7b851dbe636f2529588dbf9dca5afb0.png

现在我们的数据已被屏蔽,让我们尝试总结一下中国的所有案例:

china_masked = nbcases_ma[locations[:, 1] == "China"].sum(axis=0)
china_masked
masked_array(data=[278, 309, 574, 835, 10, 10, 17, 22, 23, 25, 28, 11821,
                   14411, 17238],
             mask=[False, False, False, False, False, False, False, False,
                   False, False, False, False, False, False],
       fill_value=999999)

请注意,这china_masked是一个掩码数组,因此它具有与常规 NumPy 数组不同的数据结构。现在,我们可以使用.data属性直接访问其数据:

china_total = china_masked.data
china_total
array([  278,   309,   574,   835,    10,    10,    17,    22,    23,
          25,    28, 11821, 14411, 17238])

这样更好:不再有负值。然而,我们仍然可以看到,有些天来,累计病例数似乎在下降(例如从835例减少到10例),这与“累计数据”的定义并不相符。如果我们仔细观察数据,我们会发现,在中国大陆数据缺失的时期,中国香港、台湾、澳门和“未指定”地区都有有效数据。也许我们可以从中国的病例总数中删除这些数据,以便更好地理解数据。

首先,我们将确定中国大陆的位置索引:

china_mask = (
    (locations[:, 1] == "China")
    & (locations[:, 0] != "Hong Kong")
    & (locations[:, 0] != "Taiwan")
    & (locations[:, 0] != "Macau")
    & (locations[:, 0] != "Unspecified*")
)

现在,china_mask是一个布尔值数组(TrueFalse);我们可以使用屏蔽数组的ma.nonzero方法检查索引是否是我们想要的:

china_mask.nonzero()
(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 31, 33]),)

现在我们可以正确地对中国大陆的条目进行求和:

china_total = nbcases_ma[china_mask].sum(axis=0)
china_total
masked_array(data=[278, 308, 440, 446, --, --, --, --, --, --, --, 11791,
                   14380, 17205],
             mask=[False, False, False, False,  True,  True,  True,  True,
                    True,  True,  True, False, False, False],
       fill_value=999999)

我们可以用这些信息替换数据并绘制一个新的图表,重点关注中国大陆:

plt.plot(dates, china_total.T, "--")
plt.xticks(selected_dates, dates[selected_dates])
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China")
Text(0.5, 1.0, 'COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China')
../_images/c49dff14e4a270d1efb32cc1ec64d03cd8149605e411938745172eece55effd6.png

很明显,屏蔽数组是这里的正确解决方案。我们无法在不错误描述曲线演变的情况下代表缺失的数据。

拟合数据#

我们能想到的一种可能性是对缺失的数据进行插值来估计一月下旬的病例数。观察我们可以使用.mask属性选择屏蔽元素:

china_total.mask
invalid = china_total[china_total.mask]
invalid
masked_array(data=[--, --, --, --, --, --, --],
             mask=[ True,  True,  True,  True,  True,  True,  True],
       fill_value=999999,
            dtype=int64)

我们还可以通过使用此掩码的逻辑非来访问有效条目:

valid = china_total[~china_total.mask]
valid
masked_array(data=[278, 308, 440, 446, 11791, 14380, 17205],
             mask=[False, False, False, False, False, False, False],
       fill_value=999999)

现在,如果我们想为此数据创建一个非常简单的近似值,我们应该考虑无效条目周围的有效条目。首先,我们选择数据有效的日期。请注意,我们可以使用掩码数组中的掩码china_total来索引日期数组:

dates[~china_total.mask]
array(['1/21/20', '1/22/20', '1/23/20', '1/24/20', '2/1/20', '2/2/20',
       '2/3/20'], dtype='<U7')

最后,我们可以使用 numpy.polynomial 包的拟合功能来创建尽可能适合数据的三次多项式模型:

t = np.arange(len(china_total))
model = np.polynomial.Polynomial.fit(t[~china_total.mask], valid, deg=3)
plt.plot(t, china_total)
plt.plot(t, model(t), "--")
[<matplotlib.lines.Line2D at 0x7f5de63c96f0>]
../_images/7b859464438ee1bc6a50b003743165e8cf499cc92c8a218ae7c3ab5a6ae9d6af.png

该图不太可读,因为线条似乎相互重叠,因此让我们用更详细的图进行总结。我们将绘制可用的真实数据,并显示不可用数据的三次拟合,使用此拟合来计算 2020 年 1 月 28 日(记录开始后 7 天)观察到的病例数的估计值:

plt.plot(t, china_total)
plt.plot(t[china_total.mask], model(t)[china_total.mask], "--", color="orange")
plt.plot(7, model(7), "r*")
plt.xticks([0, 7, 13], dates[[0, 7, 13]])
plt.yticks([0, model(7), 10000, 17500])
plt.legend(["Mainland China", "Cubic estimate", "7 days after start"])
plt.title(
    "COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China\n"
    "Cubic estimate for 7 days after start"
)
Text(0.5, 1.0, 'COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China\nCubic estimate for 7 days after start')
../_images/615121c367e83e8952cb9b82b45ead01f710d82913813c98c977b08ec2392e68.png

在实践中

  • 添加-1缺失数据不是问题numpy.genfromtxt;在这种特殊情况下,用 替换缺失值0可能很好,但稍后我们会看到这远非通用解决方案。此外,还可以numpy.genfromtxt使用参数来调用函数usemask。如果usemask=Truenumpy.genfromtxt自动返回一个掩码数组。

进一步阅读#

本教程未涵盖的主题可以在文档中找到:

参考

  • Ensheng Dong, Hongru Du, Lauren Gardner,实时跟踪 COVID-19 的基于网络的交互式仪表板,《柳叶刀传染病》,第 20 卷,第 5 期,2020 年,第 533-534 页,ISSN 1473-3099,https:/ /doi.org/10.1016/S1473-3099(20)30120-1