通过os.dup sys.stdout.fileno捕获标准输出,判断pytorch算子是否fallback到了cpu
某种设备在运行pytorch算子时,如果不支持会自动fallback到cpu,输出的tensor.device却不是cpu,我希望能获取到这个状态。本文通过捕获标准输出,根据终端是否输出fallback字符串,判断是否触发了fallback
一.代码
import threading
import sys
import os
class CheckFallback:
def __init__(self,enable=True):
self.is_fallback=False
self.enable=enable
if self.enable:
self.stdout_fileno_origin = sys.stdout.fileno()
self.stdout_fileno_dup = os.dup(self.stdout_fileno_origin)
self.stdout_pipe = os.pipe()
os.dup2(self.