李林超博客
首页
归档
留言
友链
动态
关于
归档
留言
友链
动态
关于
首页
NLP
正文
04.Pytorch张量类型转换
Leefs
2024-12-15 AM
71℃
0条
[TOC] ### 一、张量转换为 numpy 数组 - 使用`Tensor.numpy`函数可以将张量转换为ndarray数组,但是共享内存,可以使用copy函数避免共享。 **代码示例** ```python import torch # 1. 张量转换为numpy数组 def test01(): data_tensor = torch.tensor([2,3,4]) # 将张量转换为numpy数组 data_numpy = data_tensor.numpy() print(type(data_tensor)) #
print(type(data_numpy)) #
print(data_tensor) #tensor([2, 3, 4]) print(data_numpy) #[2 3 4] # 2. 张量和numpy数组共享内存 def test02(): data_tensor = torch.tensor([2, 3, 4]) data_numpy = data_tensor.numpy() # 修改张量元素的值,看看numpy数组是否会发生变化? 会发生变化 # data_tensor[0] = 100 # print(data_tensor) #tensor([100, 3, 4]) # print(data_numpy) #[100 3 4] # 修改numpy数组元素的值,看看张量是否会发生变化? 会发生变化 data_numpy[0] = 100 print(data_tensor) # tensor([100, 3, 4]) print(data_numpy) # [100 3 4] # 3. 使用copy函数实现不共享内存 def test03(): data_tensor = torch.tensor([2, 3, 4]) # 此处,发生了类型转换,可以使用拷贝函数产生新的数据,避免共享内存 data_numpy = data_tensor.numpy().copy() # 修改张量元素的值,看看numpy数组是否会发生变化? 没有发生变化 # data_tensor[0] = 100 # print(data_tensor) # tensor([100, 3, 4]) # print(data_numpy) # [2 3 4] # 修改numpy数组元素的值,看看张量是否会发生变化? 没有发生变化 data_numpy[0] = 100 print(data_tensor) # tensor([2, 3, 4]) print(data_numpy) # [100 3 4] if __name__ == '__main__': # test01() # test02() test03() ``` ### 二、numpy 转换为张量 - 使用`from_numpy`函数可以将ndarray数组转换为Tensor,默认共享内存,使用copy函数避免共享内存 - 使用`torch.tensor`可以将ndarray数组转化为Tensor,默认不共享内存 **代码示例** ```python import torch import numpy as np # 1. from_numpy 函数的用法 def test01(): data_numpy = np.array([2,3,4]) data_tensor = torch.from_numpy(data_numpy.copy()) print(type(data_numpy)) print(type(data_tensor)) # 默认共享内存 data_numpy[0] = 100 # data_tensor[0] = 100 print(data_numpy) #[100 3 4] print(data_tensor) #tensor([2, 3, 4], dtype=torch.int32) # 2. torch.tensor 函数的用法 def test02(): data_numpy = np.array([2, 3, 4]) data_tensor = torch.tensor(data_numpy) # 默认共享内存 data_numpy[0] = 100 # data_tensor[0] = 100 print(data_numpy) # [100 3 4] print(data_tensor) # tensor([2, 3, 4], dtype=torch.int32) if __name__ == '__main__': # test01() test02() ``` ### 三、标量张量和数字的转换 - 对于只有一个元素的张量,使用item方法将该值从张量中提取出来。 ```python import torch def test(): t1=torch.tensor(30) t2=torch.tensor([30]) t3=torch.tensor([[30]]) print(t1.shape) #torch.Size([]) print(t2.shape) #torch.Size([1]) print(t3.shape) #torch.Size([1, 1]) print(t1.item()) #30 print(t2.item()) #30 print(t3.item()) #30 #注意:张量中只有一个元素,如果有多个元素的话,使用item函数会报错 # ValueError: only one element tensors can be converted to Python scalars # t4 = torch.tensor([30,40]) # print(t4.item()) if __name__ == '__main__': test() ```
标签:
pytorch
非特殊说明,本博所有文章均为博主原创。
如若转载,请注明出处:
https://lilinchao.com/archives/2943.html
上一篇
03.Pytorch张量数值计算
下一篇
没有了
评论已关闭
栏目分类
随笔
2
Java
326
大数据
229
工具
31
其它
25
GO
47
NLP
4
标签云
Ubuntu
正则表达式
JavaScript
Golang基础
链表
数学
Typora
高并发
Hbase
Flink
gorm
FileBeat
队列
Spark SQL
Java工具类
字符串
JVM
Map
RSA加解密
Flume
Spark
人工智能
Filter
ajax
FastDFS
Quartz
Jquery
JavaWEB项目搭建
JavaWeb
随笔
友情链接
申请
范明明
庄严博客
Mx
陶小桃Blog
虫洞
评论已关闭