如何找出指定索引的元素并替换
假设我现在有两个距离矩阵up和down,大小都是(batch_size,pomo_size,problem_size,problem_size)。
假设batch_size=1,pomo_size=2,problem_size=5。
up = tensor([[[[8, 5, 6, 8, 2],
[2, 9, 7, 7, 0],
[5, 9, 1, 7, 7],
[0, 5, 8, 6, 3],
[5, 9, 1, 0, 2]],
[[2, 8, 1, 4, 7],
[0, 0, 2, 2, 7],
[4, 7, 7, 9, 4],
[6, 6, 7, 1, 3],
[3, 9, 9, 7, 2]]]])
down = tensor([[[[6, 1, 7, 9, 1],
[7, 6, 2, 7, 9],
[5, 3, 8, 6, 3],
[0, 8, 6, 3, 3],
[8, 9, 4, 8, 1]],
[[7, 2, 0, 6, 0],
[6, 7, 5, 3, 9],
[4, 8, 6, 6, 1],
[9, 3, 2, 2, 5],
[2, 5, 2, 1, 8]]]])
现在还有一个大小为(batch_size,pomo_size,seleced_length)的序列link,假设selected_length = 8
link = tensor([0,1,2,3,4,1,3,2]
[4,1,0,3,4,0,2,4])
我希望在down中找到link序列表示的边,然后将对应的边的值替换为up的值,拿第一个batch的第一个pomo举例,也就是在up中找到边元素(0,1)=5 (1,2)=7 (2,3)=7 (3,4)=2 (4,1)=9 (1,3)=7 (3,2)=8 (2,0)=5
替换成down中对应元素,得到的第一个batch,第一个pomo的矩阵为
[[6, 5, 7, 9, 1],
[5, 6, 7, 7, 9],
[5, 3, 8, 7, 3],
[0, 8, 8, 3, 2],
[9, 9, 4, 8, 1]],
按照你举例的数据,down的结果是啥?这个对吗?
代码是这样的,如果不对,可以调整一下:
import torch
# 定义输入张量
up = torch.tensor([[[[8, 5, 6, 8, 2],
[2, 9, 7, 7, 0],
[5, 9, 1, 7, 7],
[0, 5, 8, 6, 3],
[5, 9, 1, 0, 2]],
[[2, 8, 1, 4, 7],
[0, 0, 2, 2, 7],
[4, 7, 7, 9, 4],
[6, 6, 7, 1, 3],
[3, 9, 9, 7, 2]]]])
down = torch.tensor([[[[6, 1, 7, 9, 1],
[7, 6, 2, 7, 9],
[5, 3, 8, 6, 3],
[0, 8, 6, 3, 3],
[8, 9, 4, 8, 1]],
[[7, 2, 0, 6, 0],
[6, 7, 5, 3, 9],
[4, 8, 6, 6, 1],
[9, 3, 2, 2, 5],
[2, 5, 2, 1, 8]]]])
link = torch.tensor([[[0, 1, 2, 3, 4, 1, 3, 2],
[4, 1, 0, 3, 4, 0, 2, 4]]])
# 获取输入张量的形状信息
batch_size, pomo_size, problem_size, _ = up.size()
selected_length = link.size(2)
# 根据link序列在down中找到对应的边并替换为up的值
for i in range(batch_size):
for j in range(pomo_size):
for k in range(selected_length - 1):
from_node = link[i, j, k]
to_node = link[i, j, k+1]
down[i, j, to_node] = up[i, j, from_node, to_node]
# 打印结果
print(down)
我需要更具体的信息才能回答如何找出指定索引的元素并替换的问题。请提供以下详细信息: 1. 在哪个具体的场景中需要进行查找和替换操作? 2. 需要进行查找的元素是什么类型的?是字符串还是列表等其他类型? 3. 需要替换的元素又是什么类型的? 4. 您希望chatGPT给出什么样的回答?例如,您期望得到代码演示或是文字说明等形式的回答。
您要的是这样子的效果么?
#!/sur/bin/nve python
# coding: utf-8
def show_array(lis):
for i in lis:
print(' '.join(map(str, i)))
def show_2():
for i in ('up', 'down'):
print(f'\n{i}:')
for j in eval(i):
show_array(j)
print()
def change(link, target, lis):
n = len(link)
for i in range(n):
a, b = (i, 0) if i == n-1 else (i, i+1)
a, b = link[a], link[b]
target[a][b] = lis[a][b]
def main():
print('\nlink:')
show_array(link)
show_2()
print(f"{'':~^41}")
pomo = down[:] # 用down的副赋值给pomo。
for i in (0, 1):
change(link[i], pomo[i], up[i])
print('\npomo:')
for i in pomo:
show_array(i)
print('\n')
if __name__ == '__main__':
link = ([0,1,2,3,4,1,3,2],
[4,1,0,3,4,0,2,4])
up = [[[8, 5, 6, 8, 2],
[2, 9, 7, 7, 0],
[5, 9, 1, 7, 7],
[0, 5, 8, 6, 3],
[5, 9, 1, 0, 2]],
[[2, 8, 1, 4, 7],
[0, 0, 2, 2, 7],
[4, 7, 7, 9, 4],
[6, 6, 7, 1, 3],
[3, 9, 9, 7, 2]]]
down = [[[6, 1, 7, 9, 1],
[7, 6, 2, 7, 9],
[5, 3, 8, 6, 3],
[0, 8, 6, 3, 3],
[8, 9, 4, 8, 1]],
[[7, 2, 0, 6, 0],
[6, 7, 5, 3, 9],
[4, 8, 6, 6, 1],
[9, 3, 2, 2, 5],
[2, 5, 2, 1, 8]]]
main() # 调用主程序。