李林超博客
首页
归档
留言
友链
动态
关于
归档
留言
友链
动态
关于
首页
NLP
正文
06.Pytorch张量索引操作
Leefs
2024-12-29 AM
1412℃
0条
[TOC] ### 一、简介 #### 1.1 基本概念 张量索引是根据张量的位置或值选择特定元素或子集的过程。PyTorch 张量索引提供了一组丰富的索引操作,可以使用不同的索引方案选择和修改张量元素。 在 PyTorch 中,张量是一个多维数组,可以存储不同类型和大小的数值数据。可以使用一个或多个索引对张量进行索引,这些索引指定元素沿张量每个维度的位置。索引张量将返回一个包含所选元素的新张量或具有修改元素的原始张量的视图。 #### 1.2 索引类型 张量索引主要有以下几种类型: + **整型索引(Integer Indexing)**: + 使用一个或多个整型值来访问张量中的特定元素或子张量。 + 例如,对于一个形状为(3, 4)的二维张量A,A[1, 2]访问的是第二行第三列的元素(索引从0开始)。 + **切片索引(Slicing)**: - 使用冒号`:`表示范围,来访问张量的子集。 - 例如,对于一个形状为(3, 4)的二维张量`A`,`A[1:3, :]`访问的是第二行和第三行的所有列。 + **高级索引(Advanced Indexing)**: + 使用数组或张量作为索引,以访问特定模式的元素。 + 例如,对于一个形状为(3, 4)的二维张量`A`,`A[[0, 2], [1, 3]]`访问的是(0,1)和(2,3)位置的元素。 + **布尔索引(Boolean Indexing)**: + 使用布尔数组或张量来筛选符合条件的元素。 + 例如,对于一个形状为(3, 4)的二维张量`A`,`A[A > 5]`访问的是所有大于5的元素。 #### 1.3 摘要 | **索引类型** | **描述** | **代码方法** | | ------------ | -------------------------------------------------- | ------------------------------------------------------------ | | **整型索引** | 使用一个或多个整型值来访问张量中的特定元素或子张量 | `tensor[index]` 访问张量 `tensor` 的特定元素或子张量,例如 `tensor[1, 2]` 访问二维张量的第二行第三列元素 | | **切片索引** | 使用冒号 `:` 表示范围,访问张量的子集 | `tensor[start:stop:step]` 访问 `tensor` 的子集,例如 `tensor[1:3, :]` 访问二维张量的第二行和第三行的所有列 | | **高级索引** | 使用数组或张量作为索引,以访问特定模式的元素 | `tensor[indices]` 访问 `tensor` 的特定模式元素,例如 `tensor[[1, 2], [2, 3]]` 访问特定位置的元素 | | **布尔索引** | 使用布尔数组作为掩码,访问满足条件的元素 | `tensor[mask]` 访问 `tensor` 中满足 `mask` 条件的元素,例如 `tensor[tensor > 0]` 访问所有大于0的元素 | ### 二、索引使用 #### 2.1 整数索引 整数索引是张量索引的最基本形式,它允许使用张量沿每个维度的整数位置来选择特定元素。可以使用整数索引来选择单个元素或子张量,方法是提供与所需元素的索引相对应的整数列表。 **代码示例** ```python # 固定随机数种子 torch.manual_seed(0) data = torch.randint(0, 10, [4, 5]) # 四行五列的二维张量 print(data) print('-' * 30) print(data[2]) # 获取第三行数据,返回一维张量 print('-' * 30) print(data[:, 1]) # 获取第二列数据,返回一维张量 print('-' * 30) print(data[1, 2]) # 获取第二行的第三列数据,返回零维张量 print('-' * 30) print(data[1][2]) # 同上 ``` **运行结果** ``` tensor([[4, 9, 3, 0, 3], [9, 7, 3, 7, 3], [1, 6, 6, 9, 8], [6, 6, 8, 4, 3]]) ------------------------------ tensor([1, 6, 6, 9, 8]) ------------------------------ tensor([9, 7, 6, 6]) ------------------------------ tensor(3) ------------------------------ tensor(3) ``` #### 2.2 列表索引 在`PyTorch`和其他基于`NumPy`的库中,列表索引是一种高级索引方式,允许使用列表或数组中的索引来选择张量中的特定元素。 这种索引方式可以比传统的切片更灵活、更强大。 **基本概念:** - 列表索引:使用一个或多个列表来索引张量,提取出特定的元素。 - 广播机制:当列表索引的形状不匹配时,会自动 **`扩展(广播)`** 这些索引,使得它们可以一起使用。 **代码示例** ```python # 固定随机数种子 torch.manual_seed(0) data = torch.randint(0, 10, [4, 5]) # 四行五列的二维张量 print(data) print('-' * 30) print(data[[1, 0, 2]]) # 返回下标为1行、0行、2行共三行数据组成的3行5列的二维张量 print('-' * 30) print(data[[0, 1, 3], [3, 2, 4]]) # 返回下标为0行3列、1行2列、3行4列三个数据组成的一维张量 print('-' * 30) print(data[[[0], [1]], [[3], [4]]]) # 返回下标为0行3列、1行4列两个数据组成的2行1列的二维张量 print('-' * 30) print(data[[0, 1], [[3], [4]]]) # 返回下标为0行3列、1行3列、0行4列、1行4列四个数据组成的2行2列的二维张量 print('-' * 30) print(data[[0, 1], [[1, 2], [0, 4]]]) # 返回下标为0行1列、1行2列、0行0列、1行4列四个数据组成的2行2列的二维张量 print('-' * 30) print(data[[[1], [0]], [3, 4]]) # 返回下标为1行3列、1行4列、0行3列、0行4列四个数据组成的2行2列的二维张量 print('-' * 30) print(data[[[1, 3], [0, 2]], [3, 4]]) # 返回下标为1行3列、3行4列、0行3列、2行4列四个数据组成的2行2列的二维张量 ``` **运行结果** ``` tensor([[4, 9, 3, 0, 3], [9, 7, 3, 7, 3], [1, 6, 6, 9, 8], [6, 6, 8, 4, 3]]) ------------------------------ tensor([[9, 7, 3, 7, 3], [4, 9, 3, 0, 3], [1, 6, 6, 9, 8]]) ------------------------------ tensor([0, 3, 3]) ------------------------------ tensor([[0], [3]]) ------------------------------ tensor([[0, 7], [3, 3]]) ------------------------------ tensor([[9, 3], [4, 3]]) ------------------------------ tensor([[7, 3], [0, 3]]) ------------------------------ tensor([[7, 3], [0, 8]]) ``` #### 2.3 范围索引 **代码示例** ```py # 固定随机数种子 torch.manual_seed(0) data = torch.randint(0, 10, [4, 5]) # 四行五列的二维张量 print(data) print('-' * 30) print(data[:3, 4]) # 返回前三行的第五列数据组成的一维张量 print('-' * 30) print(data[:3, [0, 2, 4]]) # 返回前三行的第一三五列数据组成的二维张量 print('-' * 30) print(data[:3, :4]) # 返回前三行的前四列数据组成的二维张量 print('-' * 30) print(data[2:, :4]) # 返回第三行到末行的前四列数据组成的二维张量 ``` **运行结果** ``` tensor([[4, 9, 3, 0, 3], [9, 7, 3, 7, 3], [1, 6, 6, 9, 8], [6, 6, 8, 4, 3]]) ------------------------------ tensor([3, 3, 8]) ------------------------------ tensor([[4, 3, 3], [9, 3, 3], [1, 6, 8]]) ------------------------------ tensor([[4, 9, 3, 0], [9, 7, 3, 7], [1, 6, 6, 9]]) ------------------------------ tensor([[1, 6, 6, 9], [6, 6, 8, 4]]) ``` #### 2.4 布尔索引 布尔索引允许根据布尔条件选择张量的特定元素。可以使用布尔索引来选择满足特定条件的元素或屏蔽不满足条件的元素。 **代码示例** ```python # 固定随机数种子 torch.manual_seed(0) data = torch.randint(0, 10, [4, 5]) # 四行五列的二维张量 print(data) print('-' * 30) print(data[data > 5]) # 返回所有大于5的元素组成的一维张量 print('-' * 30) print(data[[True, False, True, False]]) # 返回第一行与第三行数据组成的二维张量 print('-' * 30) print(data[1:, [True, False, True, False, True]]) # 返回第二行到末行的第一三五列数据组成的二维张量 print('-' * 30) print(data[data[:, 2] > 5]) # 返回第三列大于5的行数据组成的二维张量 print('-' * 30) print(data[:, data[1] > 5]) # 返回第二行大于5的列数据组成的二维张量 ``` **运行结果** ``` tensor([[4, 9, 3, 0, 3], [9, 7, 3, 7, 3], [1, 6, 6, 9, 8], [6, 6, 8, 4, 3]]) ------------------------------ tensor([9, 9, 7, 7, 6, 6, 9, 8, 6, 6, 8]) ------------------------------ tensor([[4, 9, 3, 0, 3], [1, 6, 6, 9, 8]]) ------------------------------ tensor([[9, 3, 3], [1, 6, 8], [6, 8, 3]]) ------------------------------ tensor([[1, 6, 6, 9, 8], [6, 6, 8, 4, 3]]) ------------------------------ tensor([[4, 9, 0], [9, 7, 7], [1, 6, 9], [6, 6, 4]]) ``` #### 2.5 多维索引 **代码示例** ```python # 固定随机数种子 torch.manual_seed(0) data = torch.randint(0, 10, [3, 4, 5]) # 三片四行五列的三维张量 print(data) print('-' * 30) print(data[0, :, :]) # 返回第一片所有数据,四行五列的二维张量 print('-' * 30) print(data[:, 0, :]) # 返回所有片的第一行数据,三行五列的二维张量 print('-' * 30) print(data[:, :, 0]) # 返回所有片的第一列数据,三行四列的二维张量 ``` **运行结果** ``` tensor([[[4, 9, 3, 0, 3], [9, 7, 3, 7, 3], [1, 6, 6, 9, 8], [6, 6, 8, 4, 3]], [[6, 9, 1, 4, 4], [1, 9, 9, 9, 0], [1, 2, 3, 0, 5], [5, 2, 9, 1, 8]], [[8, 3, 6, 9, 1], [7, 3, 5, 2, 1], [0, 9, 3, 1, 1], [0, 3, 6, 6, 7]]]) ------------------------------ tensor([[4, 9, 3, 0, 3], [9, 7, 3, 7, 3], [1, 6, 6, 9, 8], [6, 6, 8, 4, 3]]) ------------------------------ tensor([[4, 9, 3, 0, 3], [6, 9, 1, 4, 4], [8, 3, 6, 9, 1]]) ------------------------------ tensor([[4, 9, 1, 6], [6, 1, 1, 5], [8, 7, 0, 0]]) ```
标签:
pytorch
非特殊说明,本博所有文章均为博主原创。
如若转载,请注明出处:
https://lilinchao.com/archives/2946.html
上一篇
05.Pytorch张量维度操作(一)
下一篇
01.Ray分布式框架介绍
评论已关闭
栏目分类
随笔
2
Java
326
大数据
229
工具
35
其它
25
GO
48
NLP
8
标签云
Kafka
JavaScript
Map
数据结构
Hbase
NIO
Spring
Ray
Flink
JVM
Quartz
Kibana
栈
散列
国产数据库改造
gorm
Zookeeper
Stream流
Scala
Kubernetes
MyBatisX
Golang基础
持有对象
Livy
SpringCloud
MyBatis-Plus
Eclipse
FastDFS
Sentinel
SpringCloudAlibaba
友情链接
申请
范明明
庄严博客
Mx
陶小桃Blog
虫洞
评论已关闭