```python
def __init__(self, tag, configs):
    self.TAG = tag
    self.cur_case = None
    self.devices = []
    self.device1 = None
    self.device2 = None
    self.project = None
    self.configs = {}
    self.__case_error = ''  # the test case's error message
    self.__case_result = RunResult.PASSED  # the test case's result
    self.__case_start_time = get_cst_time()  # the test case's start time
    self.__test_method_error = ''  # the test method's error message
    self.__test_method_result = RunResult.PASSED  # the test method's result
    self.__tests_result = []  # the tests result
    self.__result_content = {}  # the test case's result content
    self.trace_info = ''
    self.log = log
    self.step_flash_fail_msg = False
    # proxy function
    self.execption_callback = None
    # case end function
    self.case_end_callback = None
    # sub-case results
    self.sub_case_results = []
    self.testLoop = 0
    self.fail_times = 0
    self.con_fail_times = 0
```

Why don't any of these attributes appear in the parameter list of `def __init__(self, tag, configs)`?

In a Python class's `__init__` method, some attributes are initialized directly inside the method body rather than passed in as parameters. The snippet in question is a subclass initializer:

```python
def __init__(self, configs):
    self.TEST_PROCESS = "commonUI"
    self.TEST_APP = "设置"
    self.TAG = self.__class__.__name__
    TestCase.__init__(self, self.TAG, configs)
    self.uiDriver = UiDriver(self.device1)
    # class name of the running case, used when saving test data
    self.class_name = self.__class__.__name__
    # report directory
    self.case_report_path = self.get_case_report_path()
```

Only `configs` is a parameter; every other attribute is assigned internally, for the following reasons.

### 1. Core reasons

1. **Constant attributes (fixed values)**
   ```python
   self.TEST_PROCESS = "commonUI"  # fixed test-process name
   self.TEST_APP = "设置"          # fixed name of the app under test
   ```
   - These values are identical for every instance, so there is nothing to pass in.
   - Treating them as class-level constants avoids handing the same arguments to every instantiation. [^1]

2. **Attributes derived from the class itself**
   ```python
   self.TAG = self.__class__.__name__         # use the class name as a tag
   self.class_name = self.__class__.__name__  # reuse the class name
   ```
   - The value is obtained dynamically through class introspection (`__class__.__name__`).
   - It is tied to the class definition, so parameterizing it would add nothing. [^2]

3. **Attributes computed at runtime**
   ```python
   self.case_report_path = self.get_case_report_path()  # generated by a method
   ```
   - The value has to be computed during initialization (path generation, environment detection, and so on).
   - The computation is encapsulated in a method, isolating external dependencies. [^3]

4. **Component initialization (object composition)**
   ```python
   self.uiDriver = UiDriver(self.device1)  # build an internal component
   ```
   - The driver is an internal implementation detail of the class.
   - It is constructed from internal state (`self.device1`, typically set by the parent `TestCase.__init__`) rather than from an external argument. [^4]

### 2. Key design principles

| Design principle | Code example | Benefit |
|---|---|---|
| Encapsulation | `self.uiDriver = UiDriver(...)` | Hides component-creation details |
| High cohesion | `self.get_case_report_path()` | Path generation stays inside the class |
| Avoiding parameter explosion | Constants initialized internally | Keeps the constructor down to one parameter |
| Code reuse | Reusing `self.__class__.__name__` | Eliminates duplicated values |

### 3. Comparison with parameterized attributes

**Attributes that should be parameters**

```python
def __init__(self, configs):  # must be passed in
    self.configs = configs    # external dependency
```

- They come from outside: configuration varies with the environment or scenario.
- They are instance-level: different instances need different values.

**Attributes that should be initialized internally**

```python
def __init__(self):
    self.OS_TYPE = "Android"  # fixed for all instances
    self.logger = Logger()    # internal component
```

- Class-level constants shared by all instances.
- Internal state the class manages itself.
- Values that can only be determined at runtime.

### 4. Notes on inheritance

When inheritance is involved:

1. If the subclass does not override `__init__`, it automatically inherits the parent's initialization logic.
2. If the subclass overrides `__init__`, it must explicitly call the parent constructor:
   ```python
   class Child(Parent):
       def __init__(self, configs):
           super().__init__(configs)  # the essential call!
           # subclass-specific initialization...
   ```
   Otherwise the attributes the parent initializes internally (such as `self.device1`, which `self.uiDriver = UiDriver(self.device1)` depends on) are never created. [^3]

> By separating **fixed attributes** (initialized internally) from **variable dependencies** (passed as parameters), the code satisfies the open/closed principle (open for extension) and the single-responsibility principle (each class owns only its own initialization logic). [^2][^3]

---

### Related questions

1. How can a subclass extend the parent's `__init__` without breaking the existing initialization logic?
2. When should a class variable be used instead of an instance variable initialized in `__init__`?
3. In large projects, how do you keep `__init__` from becoming overly complex?
4. What is the special role of `__post_init__` in Python dataclasses?

[^1]: Constant attributes avoid repeated argument passing.
[^2]: Class introspection enables dynamic attribute initialization.
[^3]: Subclasses that override the constructor must call the superclass constructor.
[^4]: Component construction is an internal implementation detail.
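To tie these rules together, here is a minimal runnable sketch — class and attribute names are illustrative, not from the original framework — showing which values are passed in and which are initialized internally:

```python
class TestCase:
    def __init__(self, tag, configs):
        self.TAG = tag              # passed in: varies per subclass
        self.configs = configs      # passed in: external dependency
        self.device1 = "device-1"   # internal state managed by the parent


class SettingsUITest(TestCase):
    OS_TYPE = "Android"             # class-level constant: same for every instance

    def __init__(self, configs):
        # derived from the class itself, not from a parameter
        tag = self.__class__.__name__
        super().__init__(tag, configs)           # parent sets device1, etc.
        self.report_path = self._report_path()   # computed at runtime

    def _report_path(self):
        return f"/tmp/reports/{self.TAG}"


case = SettingsUITest(configs={"timeout": 30})   # only the true dependency is passed
print(case.TAG, case.report_path)
```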

Related recommendations

```python
class genTestcaselist():
    def __init__(self, pjPath):
        self.testCaselist = []
        self.testCfglist = []
        self.msgCtrlObj = setEnvMsg("Analysis testcase and testlist information")
        self.projPath = pjPath
        self.cfgObj = getTestConfig()
        self.testCount = 0
        self.waveCtrl = False
        self.seedWorkCtrl = False
        self.compVarList = []
        self.compvarCount = 0

    def genTestcase(self):
        if options.testCfgName != None:
            for cfgName in options.testCfgName:
                if cfgName not in self.testCfglist:
                    self.testCfglist.append(cfgName)
        if options.countCtrl:
            self.waveCtrl = False
            self.seedWorkCtrl = True
            if len(self.testCfglist) == 0:
                self.testCfglist.append(options.caseCfgName)
            for testName in options.testNameCtrl:
                if os.path.exists(self.projPath + "/" + testName) == False and options.testCheckCtrl == 1:
                    self.msgCtrlObj.fatalMsgPrint("No exist " + testName + " in " + self.projPath)
                else:
                    for cfgName in self.testCfglist:
                        if options.testCheckCtrl == 0:
                            self.cfgObj.getCfgInfo(self.projPath + "/", cfgName)
                        else:
                            self.cfgObj.getCfgInfo(self.projPath + "/" + testName, cfgName)
                        curCompVar = self.cfgObj.getCompOpts()
                        if self.compVarList.count(curCompVar) == 0:
                            self.compVarList.append(curCompVar)
                        self.compvarCount = self.compVarList.index(curCompVar)
                        curTestCfgName = self.cfgObj.getCfgFileName()
                        simVarList = self.cfgObj.getSimOpts()
                        preVarList = self.cfgObj.getPreSimProc()
                        postVarList = self.cfgObj.getPostSimProc()
                        falseErrVarList = self.cfgObj.getFalseErr()
                        expectErrVarList = self.cfgObj.getExpectErr()
                        falseWarnVarList = self.cfgObj.getFalseWarn()
                        expectWarnVarList = self.cfgObj.getExpectWarn()
                        for curCountVal in range(0, options.countCtrl):
                            if options.seedCtrl:
                                curSeedData = options.seedCtrl.pop(0)
                            else:
                                randSeedObj = genRandSeed()
                                curSeedData = randSeedObj.genRandData()
                            curTestcaseInfo = [testName, curSeedData, curTestCfgName,
                                               self.compvarCount, self.compVarList[self.compvarCount],
                                               simVarList, preVarList, postVarList,
                                               falseErrVarList, expectErrVarList,
                                               falseWarnVarList, expectWarnVarList]
                            self.testCaselist.append(curTestcaseInfo)
                            self.testCount += 1
        else:
            testNameList = options.testNameCtrl[:]
            testcaseCount = len(testNameList)
            if options.seedCtrl != None:
                seedCount = len(options.seedCtrl)
            else:
                seedCount = 1
            if len(self.testCfglist) == 0:
                self.testCfglist.append(options.caseCfgName)
            testCfgCount = len(self.testCfglist)
            if testcaseCount == 1:
                self.waveCtrl = options.waveCtrl
                curTestcase = testNameList.pop(0)
                if os.path.exists(self.projPath + "/" + curTestcase) == False and options.testCheckCtrl == 1:
                    self.msgCtrlObj.fatalMsgPrint("No exist " + curTestcase + " in " + self.projPath)
                else:
                    seedCountCtrl = int(math.ceil(float(seedCount) / float(testCfgCount)))
                    for cfgName in self.testCfglist:
                        if options.testCheckCtrl == 0:
                            self.cfgObj.getCfgInfo(self.projPath + "/", cfgName)
                        else:
                            self.cfgObj.getCfgInfo(self.projPath + "/" + curTestcase, cfgName)
                        curCompVar = self.cfgObj.getCompOpts()
                        if self.compVarList.count(curCompVar) == 0:
                            self.compVarList.append(curCompVar)
                        self.compvarCount = self.compVarList.index(curCompVar)
                        curTestCfgName = self.cfgObj.getCfgFileName()
                        simVarList = self.cfgObj.getSimOpts()
                        preVarList = self.cfgObj.getPreSimProc()
                        postVarList = self.cfgObj.getPostSimProc()
                        falseErrVarList = self.cfgObj.getFalseErr()
                        expectErrVarList = self.cfgObj.getExpectErr()
                        falseWarnVarList = self.cfgObj.getFalseWarn()
                        expectWarnVarList = self.cfgObj.getExpectWarn()
                        for seedIdx in range(0, seedCountCtrl):
                            if options.seedCtrl:
                                curSeedData = options.seedCtrl.pop(0)
                            else:
                                randSeedObj = genRandSeed()
                                curSeedData = randSeedObj.genRandData()
                            curTestcaseInfo = [curTestcase, curSeedData, curTestCfgName,
                                               self.compvarCount, self.compVarList[self.compvarCount],
                                               simVarList, preVarList, postVarList,
                                               falseErrVarList, expectErrVarList,
                                               falseWarnVarList, expectWarnVarList]
                            self.testCaselist.append(curTestcaseInfo)
                            self.testCount += 1
                    if self.testCount > 1:
                        self.seedWorkCtrl = True
                    else:
                        self.seedWorkCtrl = options.seedWorkCtrl
            else:
                self.waveCtrl = options.waveCtrl
                for testName in options.testNameCtrl:
                    if os.path.exists(self.projPath + "/" + testName) == False and options.testCheckCtrl == 1:
                        self.msgCtrlObj.fatalMsgPrint("No exist " + testName + " in " + self.projPath)
                    else:
                        if len(self.testCfglist):
                            cfgName = self.testCfglist.pop(0)
                        else:
                            cfgName = options.caseCfgName
                        if options.testCheckCtrl == 0:
                            self.cfgObj.getCfgInfo(self.projPath + "/", cfgName)
                        else:
                            self.cfgObj.getCfgInfo(self.projPath + "/" + testName, cfgName)
                        curCompVar = self.cfgObj.getCompOpts()
                        if self.compVarList.count(curCompVar) == 0:
                            self.compVarList.append(curCompVar)
                        self.compvarCount = self.compVarList.index(curCompVar)
                        curTestCfgName = self.cfgObj.getCfgFileName()
                        simVarList = self.cfgObj.getSimOpts()
                        preVarList = self.cfgObj.getPreSimProc()
                        postVarList = self.cfgObj.getPostSimProc()
                        falseErrVarList = self.cfgObj.getFalseErr()
                        expectErrVarList = self.cfgObj.getExpectErr()
                        falseWarnVarList = self.cfgObj.getFalseWarn()
                        expectWarnVarList = self.cfgObj.getExpectWarn()
                        if options.seedCtrl:
                            curSeedData = options.seedCtrl.pop(0)
                        else:
                            randSeedObj = genRandSeed()
                            curSeedData = randSeedObj.genRandData()
                        curTestcaseInfo = [testName, curSeedData, curTestCfgName,
                                           self.compvarCount, self.compVarList[self.compvarCount],
                                           simVarList, preVarList, postVarList,
                                           falseErrVarList, expectErrVarList,
                                           falseWarnVarList, expectWarnVarList]
                        self.testCaselist.append(curTestcaseInfo)
                        self.testCount += 1
                if self.testCount > 1:
                    self.seedWorkCtrl = True
                else:
                    self.seedWorkCtrl = options.seedWorkCtrl

    def genTestlist(self):
        curTestname = ""
        curCount = ""
        curTestCfg = ""
        curCompVar = ""
        simVar = ""
        seedVar = ""
        seedVarList = []  # fixed: initialize so lines without a seed column fall back to random seeds
        lineInfo = ""
        fileHandle = open(options.testListFile, 'r')
        self.waveCtrl = False
        self.seedWorkCtrl = True
        while True:
            lineContext = fileHandle.readline()
            if not lineContext:
                break
            lineInfo = lineContext.strip()
            if len(lineInfo) == 0:
                continue
            reInfo = "^\s*#.*"
            lineObj = re.match(reInfo, lineContext)
            if lineObj:
                continue
            else:
                testInfoQueue = lineContext.split(";")
                if len(testInfoQueue) <= 2:
                    self.msgCtrlObj.fatalMsgPrint("Regression list must include testcase and count, current line: %0s" % (lineContext))
                curTestname = testInfoQueue[0].strip()
                curCount = testInfoQueue[1].strip()
                if len(testInfoQueue) > 2:
                    curTestCfg = testInfoQueue[2].strip()
                if len(testInfoQueue) > 3:
                    curCompVar = testInfoQueue[3].strip()
                if len(testInfoQueue) > 4:
                    simVar = testInfoQueue[4].strip()
                # support multi seed
                if len(testInfoQueue) > 5:
                    seedVar = testInfoQueue[5].strip()
                    seedTempVar = re.sub(' +', ' ', seedVar)
                    seedVarList = seedTempVar.split(' ')
                if os.path.exists(self.projPath + "/" + curTestname) == False and options.testCheckCtrl == 1:
                    self.msgCtrlObj.fatalMsgPrint("No exist " + curTestname + " in " + self.projPath)
                else:
                    testMacroVarList = []
                    noMacroVarList = []
                    compMacroVar = ""
                    compNoMacroVar = ""
                    if curTestCfg:
                        testCfgFile = curTestCfg
                    elif options.testCfgName != None:
                        testCfgFile = options.testCfgName.pop(0)
                    else:
                        testCfgFile = options.caseCfgName
                    if options.testCheckCtrl == 0:
                        self.cfgObj.getCfgInfo(self.projPath + "/", testCfgFile)
                    else:
                        self.cfgObj.getCfgInfo(self.projPath + "/" + curTestname, testCfgFile)
                    # test.cfg compile parameter
                    testCompVar = self.cfgObj.getCompOpts()
                    compOptsVar = re.sub(' +', ' ', testCompVar)
                    CompVarList = compOptsVar.split(' ')
                    CompVarList.sort()
                    for macroVar in CompVarList:
                        if macroVar:
                            macroObj = re.match('\+define\+', macroVar)
                            if macroObj:
                                macroTmpVar = re.sub('\+define\+', '', macroVar)
                                macroVarList = macroTmpVar.split('+')
                                macroVarList.sort()
                                for compVar in macroVarList:
                                    if compVar not in testMacroVarList:
                                        testMacroVarList.append(compVar)
                            else:
                                compVar = re.sub(':', ' ', macroVar)
                                noMacroVarList.append(compVar)
                    # List compile parameter
                    testCompVar = curCompVar
                    compOptsVar = re.sub(' +', ' ', testCompVar)
                    CompVarList = compOptsVar.split(' ')
                    CompVarList.sort()
                    for macroVar in CompVarList:
                        if macroVar:
                            macroObj = re.match('\+define\+', macroVar)
                            if macroObj:
                                macroTmpVar = re.sub('\+define\+', '', macroVar)
                                macroVarList = macroTmpVar.split('+')
                                macroVarList.sort()
                                for compVar in macroVarList:
                                    if compVar not in testMacroVarList:
                                        testMacroVarList.append(compVar)
                            else:
                                compVar = re.sub(':', ' ', macroVar)
                                noMacroVarList.append(compVar)
                    # Merge all compile parameters
                    if len(testMacroVarList):
                        compMacroVar = "+define+" + "+".join(testMacroVarList)
                    if len(noMacroVarList):
                        compNoMacroVar = " ".join(noMacroVarList)
                    if compMacroVar or compNoMacroVar:
                        compMacroVar = compMacroVar + " " + compNoMacroVar
                    if self.compVarList.count(compMacroVar) == 0:
                        self.compVarList.append(compMacroVar)
                    self.compvarCount = self.compVarList.index(compMacroVar)
                    # Sim parameter
                    simVarList = self.cfgObj.getSimOpts()
                    if simVar:
                        simVarList = simVarList.strip() + " " + simVar
                    curTestCfgName = self.cfgObj.getCfgFileName()
                    preVarList = self.cfgObj.getPreSimProc()
                    postVarList = self.cfgObj.getPostSimProc()
                    falseErrVarList = self.cfgObj.getFalseErr()
                    expectErrVarList = self.cfgObj.getExpectErr()
                    falseWarnVarList = self.cfgObj.getFalseWarn()
                    expectWarnVarList = self.cfgObj.getExpectWarn()
                    for curCountVal in range(0, int(curCount)):
                        if curCountVal >= len(seedVarList):
                            randSeedObj = genRandSeed()
                            curSeedData = randSeedObj.genRandData()
                        elif seedVarList[curCountVal]:
                            curSeedData = int(seedVarList[curCountVal])
                        else:
                            randSeedObj = genRandSeed()
                            curSeedData = randSeedObj.genRandData()
                        curTestcaseInfo = [curTestname, curSeedData, curTestCfgName,
                                           self.compvarCount, self.compVarList[self.compvarCount],
                                           simVarList, preVarList, postVarList,
                                           falseErrVarList, expectErrVarList,
                                           falseWarnVarList, expectWarnVarList]
                        self.testCaselist.append(curTestcaseInfo)
                        self.testCount += 1
        fileHandle.close()

    def getTestList(self):
        return self.testCaselist

    def getWaveCtrl(self):
        return self.waveCtrl

    def getSeedPathCtrl(self):
        return self.seedWorkCtrl

    def getTestcaseCount(self):
        return self.testCount

    def getCompVar(self):
        return self.compVarList
```
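For reference, `genTestlist()` splits each non-comment line of the regression list on `;` into a test name, a run count, and optional config/compile/sim/seed fields. A hypothetical line in that format, parsed the same way (the field layout follows the code above; the values are illustrative):

```python
# name ; count ; test cfg ; extra compile opts ; extra sim opts ; seeds
line = "smoke_test ; 2 ; smoke.cfg ; +define+FAST_SIM ; +UVM_VERBOSITY=UVM_LOW ; 1234 5678"

fields = [f.strip() for f in line.split(";")]
name, count = fields[0], int(fields[1])
cfg   = fields[2] if len(fields) > 2 else ""
comp  = fields[3] if len(fields) > 3 else ""
sim   = fields[4] if len(fields) > 4 else ""
seeds = fields[5].split() if len(fields) > 5 else []   # missing seeds -> random per run
print(name, count, cfg, comp, sim, seeds)
```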

```python
class KeyWordSpotter(torch.nn.Module):
    def __init__(
        self,
        ckpt_path,
        config_path,
        token_path,
        lexicon_path,
        threshold,
        min_frames=5,
        max_frames=250,
        interval_frames=50,
        score_beam=3,
        path_beam=20,
        gpu=-1,
        is_jit_model=False,
    ):
        super().__init__()
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
        with open(config_path, 'r') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
        dataset_conf = configs['dataset_conf']

        # feature related
        self.sample_rate = 16000
        self.wave_remained = np.array([])
        self.num_mel_bins = dataset_conf['feature_extraction_conf']['num_mel_bins']
        self.frame_length = dataset_conf['feature_extraction_conf']['frame_length']  # in ms
        self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift']  # in ms
        self.downsampling = dataset_conf.get('frame_skip', 1)
        self.resolution = self.frame_shift / 1000  # in second

        # fsmn splice operation
        self.context_expansion = dataset_conf.get('context_expansion', False)
        self.left_context = 0
        self.right_context = 0
        if self.context_expansion:
            self.left_context = dataset_conf['context_expansion_conf']['left']
            self.right_context = dataset_conf['context_expansion_conf']['right']
        self.feature_remained = None
        self.feats_ctx_offset = 0  # after downsample, offset exist.

        # model related
        if is_jit_model:
            model = torch.jit.load(ckpt_path)
            # For script model, only cpu is supported.
            device = torch.device('cpu')
        else:
            # Init model from configs
            model = init_model(configs['model'])
            load_checkpoint(model, ckpt_path)
            use_cuda = gpu >= 0 and torch.cuda.is_available()
            device = torch.device('cuda' if use_cuda else 'cpu')
        self.device = device
        self.model = model.to(device)
        self.model.eval()
        logging.info(f'model {ckpt_path} loaded.')
        self.token_table = read_token(token_path)
        logging.info(f'tokens {token_path} with '
                     f'{len(self.token_table)} units loaded.')
        self.lexicon_table = read_lexicon(lexicon_path)
        logging.info(f'lexicons {lexicon_path} with '
                     f'{len(self.lexicon_table)} units loaded.')
        self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)

        # decoding and detection related
        self.score_beam = score_beam
        self.path_beam = path_beam
        self.threshold = threshold
        self.min_frames = min_frames
        self.max_frames = max_frames
        self.interval_frames = interval_frames
        self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
        self.hit_score = 1.0
        self.hit_keyword = None
        self.activated = False
        self.total_frames = 0  # frame offset, for absolute time
        self.last_active_pos = -1  # the last frame of being activated
        self.result = {}

    def set_keywords(self, keywords):
        # parse keywords tokens
        assert keywords is not None, \
            'at least one keyword is needed, ' \
            'multiple keywords should be splitted with comma(,)'
        keywords_str = keywords
        keywords_list = keywords_str.strip().replace(' ', '').split(',')
        keywords_token = {}
        keywords_idxset = {0}
        keywords_strset = {'<blk>'}
        keywords_tokenmap = {'<blk>': 0}
        for keyword in keywords_list:
            strs, indexes = query_token_set(keyword, self.token_table,
                                            self.lexicon_table)
            keywords_token[keyword] = {}
            keywords_token[keyword]['token_id'] = indexes
            keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
                                                           for i in indexes)
            [keywords_strset.add(i) for i in strs]
            [keywords_idxset.add(i) for i in indexes]
            for txt, idx in zip(strs, indexes):
                if keywords_tokenmap.get(txt, None) is None:
                    keywords_tokenmap[txt] = idx
        token_print = ''
        for txt, idx in keywords_tokenmap.items():
            token_print += f'{txt}({idx}) '
        logging.info(f'Token set is: {token_print}')
        self.keywords_idxset = keywords_idxset
        self.keywords_token = keywords_token

    def accept_wave(self, wave):
        assert isinstance(wave, bytes), \
            "please make sure the input format is bytes(raw PCM)"
        # convert bytes into float32
        data = []
        for i in range(0, len(wave), 2):
            value = struct.unpack('<h', wave[i:i + 2])[0]
            data.append(value)
        # here we don't divide 32768.0,
        # because kaldi.fbank accept original input
        wave = np.array(data)
        wave = np.append(self.wave_remained, wave)
        if wave.size < (self.frame_length * self.sample_rate / 1000) \
                * self.right_context:
            self.wave_remained = wave
            return None
        wave_tensor = torch.from_numpy(wave).float().to(self.device)
        wave_tensor = wave_tensor.unsqueeze(0)  # add a channel dimension
        feats = kaldi.fbank(wave_tensor,
                            num_mel_bins=self.num_mel_bins,
                            frame_length=self.frame_length,
                            frame_shift=self.frame_shift,
                            dither=0,
                            energy_floor=0.0,
                            sample_frequency=self.sample_rate)
        # update wave remained
        feat_len = len(feats)
        frame_shift = int(self.frame_shift / 1000 * self.sample_rate)
        self.wave_remained = wave[feat_len * frame_shift:]

        if self.context_expansion:
            assert feat_len > self.right_context, \
                "make sure each chunk feat length is large than right context."
            # pad feats with remained feature from last chunk
            if self.feature_remained is None:  # first chunk
                # pad first frame at the beginning;
                # replicate just supports the last dimension, so we transpose.
                feats_pad = F.pad(feats.T, (self.left_context, 0),
                                  mode='replicate').T
            else:
                feats_pad = torch.cat((self.feature_remained, feats))
            ctx_frm = feats_pad.shape[0] - (self.right_context + self.right_context)
            ctx_win = (self.left_context + self.right_context + 1)
            ctx_dim = feats.shape[1] * ctx_win
            feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
            for i in range(ctx_frm):
                feats_ctx[i] = torch.cat(tuple(
                    feats_pad[i:i + ctx_win])).unsqueeze(0)
            # update feature remained, and feats
            self.feature_remained = \
                feats[-(self.left_context + self.right_context):]
            feats = feats_ctx.to(self.device)
        if self.downsampling > 1:
            last_remainder = 0 if self.feats_ctx_offset == 0 \
                else self.downsampling - self.feats_ctx_offset
            remainder = (feats.size(0) + last_remainder) % self.downsampling
            feats = feats[self.feats_ctx_offset::self.downsampling, :]
            self.feats_ctx_offset = remainder \
                if remainder == 0 else self.downsampling - remainder
        return feats

    def decode_keywords(self, t, probs):
        absolute_time = t + self.total_frames
        # search next_hyps depending on current probs and hyps.
        next_hyps = ctc_prefix_beam_search(absolute_time, probs, self.cur_hyps,
                                           self.keywords_idxset,
                                           self.score_beam)
        # update cur_hyps. note: the hyps are sorted by path score (pnb+pb),
        # not the keywords' probabilities.
        cur_hyps = next_hyps[:self.path_beam]
        self.cur_hyps = cur_hyps

    def execute_detection(self, t):
        absolute_time = t + self.total_frames
        hit_keyword = None
        start = 0
        end = 0
        # hyps for detection
        hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
        # detect keywords in decoding paths.
        for one_hyp in hyps:
            prefix_ids = one_hyp[0]
            # path_score = one_hyp[1]
            prefix_nodes = one_hyp[2]
            assert len(prefix_ids) == len(prefix_nodes)
            for word in self.keywords_token.keys():
                lab = self.keywords_token[word]['token_id']
                offset = is_sublist(prefix_ids, lab)
                if offset != -1:
                    hit_keyword = word
                    start = prefix_nodes[offset]['frame']
                    end = prefix_nodes[offset + len(lab) - 1]['frame']
                    for idx in range(offset, offset + len(lab)):
                        self.hit_score *= prefix_nodes[idx]['prob']
                    break
            if hit_keyword is not None:
                self.hit_score = math.sqrt(self.hit_score)
                break
        duration = end - start
        if hit_keyword is not None:
            if self.hit_score >= self.threshold and \
                    self.min_frames <= duration <= self.max_frames \
                    and (self.last_active_pos == -1 or
                         end - self.last_active_pos >= self.interval_frames):
                self.activated = True
                self.last_active_pos = end
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"duration {duration}, score {self.hit_score}, Activated.")
            elif self.last_active_pos > 0 and \
                    end - self.last_active_pos < self.interval_frames:
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"but interval {end - self.last_active_pos} "
                    f"is lower than {self.interval_frames}, Deactivated. ")
            elif self.hit_score < self.threshold:
                logging.info(f"Frame {absolute_time} detect {hit_keyword} "
                             f"from {start} to {end} frame. "
                             f"but {self.hit_score} "
                             f"is lower than {self.threshold}, Deactivated. ")
            elif self.min_frames > duration or duration > self.max_frames:
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"but {duration} beyond range"
                    f"({self.min_frames}~{self.max_frames}), Deactivated. ")
        self.result = {
            "state": 1 if self.activated else 0,
            "keyword": hit_keyword if self.activated else None,
            "start": start * self.resolution if self.activated else None,
            "end": end * self.resolution if self.activated else None,
            "score": self.hit_score if self.activated else None
        }

    def forward(self, wave_chunk):
        feature = self.accept_wave(wave_chunk)
        if feature is None or feature.size(0) < 1:
            return {}  # the feature is not enough to get a result.
        feature = feature.unsqueeze(0)  # add a batch dimension
        logits, self.in_cache = self.model(feature, self.in_cache)
        probs = logits.softmax(2)  # (batch_size, maxlen, vocab_size)
        probs = probs[0].cpu()  # remove batch dimension
        for (t, prob) in enumerate(probs):
            t *= self.downsampling
            self.decode_keywords(t, prob)
            self.execute_detection(t)
            if self.activated:
                self.reset()
                # since a chunk includes about 30 frames,
                # once activated, we can skip the latter frames.
                # TODO: there should be another method to update result,
                # avoiding self.result being cleared.
                break
        # update frame offset
        self.total_frames += len(probs) * self.downsampling
        # For streaming kws, the cur_hyps should be reset if a possible
        # keyword lasts longer than the max_frames value you set.
        # see this issue: https://github.com/duj12/kws_demo/issues/2
        if len(self.cur_hyps) > 0 and len(self.cur_hyps[0][0]) > 0:
            keyword_may_start = int(self.cur_hyps[0][1][2][0]['frame'])
            if (self.total_frames - keyword_may_start) > self.max_frames:
                self.reset()
        return self.result

    def reset(self):
        self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
        self.activated = False
        self.hit_score = 1.0

    def reset_all(self):
        self.reset()
        self.wave_remained = np.array([])
        self.feature_remained = None
        self.feats_ctx_offset = 0  # after downsample, offset exist.
        self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
        self.total_frames = 0  # frame offset, for absolute time
        self.last_active_pos = -1  # the last frame of being activated
        self.result = {}
```

Please help me sort out the overall flow of this code.
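To trace the flow end to end, a minimal streaming loop might look like this. The file names and keyword are illustrative assumptions; the method calls mirror the class above (fbank features in `accept_wave`, model forward plus CTC prefix beam search in `forward`, detection in `execute_detection`):

```python
spotter = KeyWordSpotter(ckpt_path='kws.pt',           # hypothetical paths
                         config_path='config.yaml',
                         token_path='tokens.txt',
                         lexicon_path='lexicon.txt',
                         threshold=0.02)
spotter.set_keywords('你好问问')                        # 1. map keyword(s) to token-id sequences

with open('audio_16k_s16le.pcm', 'rb') as f:
    while chunk := f.read(3200):                       # 2. 100 ms of 16 kHz 16-bit mono PCM
        result = spotter(chunk)                        # 3. features -> model -> beam search -> detect
        if result.get('state') == 1:
            print('hit:', result['keyword'], result['start'], result['end'])
            spotter.reset_all()                        # 4. clear all caches before the next utterance
```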

```python
class FeatureProcess:
    def __init__(self, logger, usr_conf=None):
        self.logger = logger
        self.usr_conf = usr_conf or {}
        self.reset()

    def reset(self):
        # 1. Static road information
        self.junction_dict = {}
        self.edge_dict = {}
        self.lane_dict = {}
        self.vehicle_configs = {}
        self.l_id_to_index = {}
        self.lane_to_junction = {}
        self.phase_lane_mapping = {}
        # 2. Dynamic per-frame information
        self.vehicle_status = {}
        self.lane_volume = {}  # fixed: stores lane ID -> vehicle count (integer, protocol v_count)
        self.lane_congestion = {}
        self.current_phases = {}
        self.lane_demand = {}
        # 3. Vehicle history / trajectory data
        self.vehicle_prev_junction = {}
        self.vehicle_prev_position = {}
        self.vehicle_distance_store = {}
        self.last_waiting_moment = {}
        self.waiting_time_store = {}
        self.enter_lane_time = {}
        self.vehicle_enter_time = {}
        self.vehicle_ideal_time = {}
        self.current_vehicles = {}  # stores vehicle ID -> full vehicle info
        self.vehicle_trajectory = {}
        # 4. Scenario state
        self.peak_hour = False
        self.weather = 0
        self.accident_lanes = set()
        self.control_lanes = set()
        self.accident_configs = []
        self.control_configs = []
        # 5. Accumulated junction-level metrics
        self.junction_metrics = {}
        self.enter_lane_ids = set()
        # Region information
        self.region_dict = {}
        self.region_capacity = {}
        self.junction_region_map = {}

    def init_road_info(self, start_info):
        """Fully protocol-compatible initialization."""
        junctions = start_info.get("junctions", [])
        signals = start_info.get("signals", [])
        edges = start_info.get("edges", [])
        lane_configs = start_info.get("lane_configs", [])
        vehicle_configs = start_info.get("vehicle_configs", [])
        # 1. Init junction mapping (fixed protocol field parsing)
        self._init_junction_mapping(junctions)
        # 2. Rebuild phase-lane mapping (fixed turn-mask parsing)
        self._init_phase_mapping(signals, junctions)
        # 3. Process lane configs
        self._init_lane_configs(lane_configs)
        # 4. Process vehicle configs (with enum matching)
        self._init_vehicle_configs(vehicle_configs)
        # 5. Init regions
        self._init_protocol_regions(junctions)
        # 6. Init scenario config
        self._init_scene_config()
        self.logger.info("Road info initialization complete")

    def _init_junction_mapping(self, junctions):
        """Init junction mapping; the position field is no longer required."""
        self.junctions = {}
        self.junction_list = []
        if not junctions:
            self.logger.warning("Junction list is empty, skipping initialization")
            return
        valid_junctions = 0
        for idx, junction in enumerate(junctions):
            # Ensure a j_id field exists (reusing the earlier fix)
            if 'j_id' not in junction:
                if 'id' in junction:
                    junction['j_id'] = f"j_{junction['id']}"
                else:
                    junction['j_id'] = f"junction_{idx}"
                self.logger.debug(f"Generated j_id for junction {idx}: {junction['j_id']}")
            j_id = junction['j_id']
            # No longer check the position field
            self.junctions[j_id] = junction
            self.junction_list.append(junction)
            valid_junctions += 1
            self.logger.debug(f"Loaded junction {j_id}")
        self.logger.info(f"Successfully loaded {valid_junctions} junctions")

    def _init_phase_mapping(self, signals, junctions):
        """Fixed: parse the phase-lane mapping from the protocol DirectionMask (turn type)."""
        self.phase_lane_mapping = {}
        signal_junction_map = {j["signal"]: j["j_id"] for j in junctions if "signal" in j}
        for signal in signals:
            s_id = signal["s_id"]
            j_id = signal_junction_map.get(s_id)
            if not j_id or j_id not in self.junction_dict:
                self.logger.warning(f"Signal {s_id} is not linked to a valid junction, skipping phase mapping")
                continue
            junction = self.junction_dict[j_id]
            self.phase_lane_mapping[s_id] = {}
            all_enter_lanes = junction["cached_enter_lanes"]
            # Walk each phase and associate its lanes
            for phase_idx, phase in enumerate(signal.get("phases", [])):
                controlled_lanes = []
                for light_cfg in phase.get("lights_on_configs", []):
                    green_mask = light_cfg.get("green_mask", 0)
                    # Fixed: parse turn types (straight/left/right/u-turn) rather than compass directions
                    turns = self._mask_to_turns(green_mask)
                    for turn in turns:
                        lanes = self._get_turn_lanes(junction, turn)  # get lanes by turn type
                        controlled_lanes.extend(lanes)
                # Keep only valid entry lanes
                valid_lanes = [l for l in set(controlled_lanes) if l in all_enter_lanes]
                self.phase_lane_mapping[s_id][phase_idx] = valid_lanes

    def _init_lane_configs(self, lane_configs):
        """Fixed: parse the turn type from the protocol DirectionMask."""
        self.lane_dict = {}
        for lane in lane_configs:
            l_id = lane["l_id"]
            dir_mask = lane.get("dir_mask", 0)  # protocol: DirectionMask enum value
            # Fixed: decode the turn type per the protocol DirectionMask
            turn_type = 0  # 0=straight, 1=left, 2=right, 3=u-turn (matches the protocol)
            if dir_mask & 2:    # Left=2 (protocol)
                turn_type = 1
            elif dir_mask & 4:  # Right=4 (protocol)
                turn_type = 2
            elif dir_mask & 8:  # UTurn=8 (protocol)
                turn_type = 3
            # otherwise Straight=1 (defaults to straight)
            self.lane_dict[l_id] = {
                "l_id": l_id,
                "edge_id": lane.get("edge_id", 0),
                "length": lane.get("length", 100),
                "width": lane.get("width", 3),
                "turn_type": turn_type
            }

    def _init_vehicle_configs(self, vehicle_configs):
        """Init vehicle configs with extra safety checks and logging."""
        self.logger.debug(f"Initializing vehicle configs, received: {len(vehicle_configs)}")
        # Log all incoming config IDs
        config_ids = [cfg["v_config_id"] for cfg in vehicle_configs] if vehicle_configs else []
        self.logger.info(f"Raw config ID list: {config_ids}")
        # Add a default config if missing
        if not vehicle_configs:
            self.logger.warning("⚠️ Vehicle configs are empty, adding a default config")
            vehicle_configs = [
                {"v_config_id": 2, "v_type": 1, "max_speed": 10.0}
            ]
        # Make sure required configs exist
        required_ids = [2]  # per the error log, ID=2 must exist
        for req_id in required_ids:
            if req_id not in [cfg["v_config_id"] for cfg in vehicle_configs]:
                self.logger.warning(f"⚠️ Missing required config ID={req_id}, adding a default")
                vehicle_configs.append({
                    "v_config_id": req_id,
                    "v_type": 1,
                    "max_speed": 10.0
                })
        # Load configs and log details
        self.vehicle_configs = {}
        for cfg in vehicle_configs:
            cfg_id = cfg["v_config_id"]
            self.vehicle_configs[cfg_id] = {
                "v_type": cfg.get("v_type", 1),
                "max_speed": cfg.get("max_speed", 10.0)
            }
            self.logger.debug(f"Loaded config ID={cfg_id}: type={cfg.get('v_type', 1)}, max speed={cfg.get('max_speed', 10.0)}")
        self.logger.info(f"✅ Vehicle configs loaded, ID list: {list(self.vehicle_configs.keys())}")

    def _init_protocol_regions(self, junctions):
        """Protocol-compatible region initialization."""
        self.region_dict = {}
        self.region_capacity = {}
        self.junction_region_map = {}
        # Each junction forms its own region
        for junction in junctions:
            j_id = junction.get("j_id")
            if not j_id or j_id not in self.junction_dict:
                continue
            region_id = f"region_{j_id}"
            self.region_dict[region_id] = [j_id]
            self.junction_region_map[j_id] = region_id
            # Region capacity = total capacity of entry lanes
            enter_lanes = self.junction_dict[j_id]["cached_enter_lanes"]
            capacity = sum(self.get_lane_capacity(l) for l in enter_lanes)
            self.region_capacity[region_id] = capacity

    def _mask_to_turns(self, green_mask):
        """Fixed: decode green_mask into protocol DirectionMask turn types."""
        turns = []
        if green_mask & 1:  # Straight=1 (protocol)
            turns.append("straight")
        if green_mask & 2:  # Left=2 (protocol)
            turns.append("left")
        if green_mask & 4:  # Right=4 (protocol)
            turns.append("right")
        if green_mask & 8:  # UTurn=8 (protocol)
            turns.append("uturn")
        return turns

    def _get_turn_lanes(self, junction, turn):
        """Fixed: select lanes by turn type (replaces the old compass-direction logic)."""
        # Strategy: match each lane's turn_type against the requested turn
        turn_map = {
            "straight": 0,
            "left": 1,
            "right": 2,
            "uturn": 3
        }
        target_turn_type = turn_map.get(turn)
        if target_turn_type is None:
            return []
        # Filter entry lanes whose turn type matches
        enter_lanes = junction["cached_enter_lanes"]
        matched_lanes = []
        for lane_id in enter_lanes:
            lane = self.lane_dict.get(lane_id)
            if lane and lane["turn_type"] == target_turn_type:
                matched_lanes.append(lane_id)
        return matched_lanes

    def update_traffic_info(self, obs, extra_info):
        if "framestate" not in obs:
            self.logger.error("Observation is missing the framestate field")
            return
        frame_state = obs["framestate"]
        frame_no = frame_state.get("frame_no", 0)
        frame_time_ms = frame_state.get("frame_time", 0)
        frame_time = frame_time_ms / 1000.0
        # Extract protocol fields
        vehicles = frame_state.get("vehicles", [])
        phases = frame_state.get("phases", [])
        lanes = frame_state.get("lanes", [])  # protocol: per-lane vehicle-count info
        # 1. Update phase info
        self.current_phases.clear()
        for phase_info in phases:
            s_id = phase_info.get("s_id", 0)
            self.current_phases[s_id] = {
                "remaining_duration": phase_info.get("remaining_duration", 0),
                "phase_id": phase_info.get("phase_id", 0)
            }
        # 2. Update lane volumes (fixed: use the protocol v_count field)
        self.lane_volume.clear()
        for lane in lanes:
            l_id = lane.get("lane_id", -1)  # fixed: the protocol field is lane_id
            if l_id == -1:
                continue
            # fixed: the protocol reports volume as an integer v_count, not a vehicles field
            self.lane_volume[l_id] = lane.get("v_count", 0)
        # 3. Update current vehicles (fixed: use the protocol v_status field)
        self.current_vehicles.clear()
        self.vehicle_status.clear()
        current_v_ids = set()
        for vehicle in vehicles:
            v_id = vehicle.get("v_id")
            if not v_id:
                continue
            self.current_vehicles[v_id] = vehicle
            # fixed: the protocol field is v_status, not status
            self.vehicle_status[v_id] = vehicle.get("v_status", 0)  # 0=normal, 1=accident, 2=rule-free
            current_v_ids.add(v_id)
        # 4. Clean up stale vehicle data
        self._clean_expired_vehicle_data(current_v_ids)
        # 5. Compute waiting time and travel distance (update history)
        for vehicle in vehicles:
            v_id = vehicle.get("v_id")
            if not v_id:
                continue
            # Update vehicle position
            if "position_in_lane" in vehicle:
                self._update_vehicle_position(vehicle, frame_time)
            # Waiting time
            self.cal_waiting_time(frame_time, vehicle)
            # Travel distance
            self.cal_travel_distance(vehicle)
        # 6. Update accident/control lanes
        self._update_active_invalid_lanes(frame_no)
        # 7. Compute lane demand (adapted to integer lane_volume)
        self.calculate_lane_demand()

    def _update_vehicle_position(self, vehicle, frame_time):
        v_id = vehicle["v_id"]
        current_pos = vehicle["position_in_lane"]
        if v_id in self.vehicle_prev_position:
            prev_pos = self.vehicle_prev_position[v_id]
            # Distance travelled (integer-position compatible)
            dx = int(current_pos["x"]) - int(prev_pos["x"])
            dy = int(current_pos["y"]) - int(prev_pos["y"])
            distance = math.hypot(dx, dy)
            # Accumulate travelled distance
            self.vehicle_distance_store[v_id] = self.vehicle_distance_store.get(v_id, 0.0) + distance
        self.vehicle_prev_position[v_id] = current_pos

    def calculate_junction_queue(self, j_id, vehicles, current_v_ids):
        junction = self.junction_dict.get(j_id)
        if not junction:
            self.logger.warning(f"Queue computation failed: junction {j_id} does not exist")
            return 0
        invalid_lanes = self.get_invalid_lanes()
        enter_lanes = junction["cached_enter_lanes"]
        valid_lanes = [l for l in enter_lanes if l not in invalid_lanes]
        queue = 0
        vehicle_dict = {v["v_id"]: v for v in vehicles}
        for lane_id in valid_lanes:
            # Iterate vehicle IDs on this lane (filtered from vehicles, since lane_volume is now just v_count)
            for v_id in [v["v_id"] for v in vehicles if v.get("lane") == lane_id]:
                if v_id not in current_v_ids or self.vehicle_status.get(v_id, 0) != 0:
                    continue
                vehicle = vehicle_dict.get(v_id, {})
                if vehicle.get("speed", 0) <= 1.0:
                    queue += 1
        return queue

    def _clean_expired_vehicle_data(self, current_v_ids):
        all_v_ids = set(
            self.vehicle_prev_position.keys() |
            self.vehicle_enter_time.keys() |
            self.waiting_time_store.keys() |
            self.vehicle_prev_junction.keys() |
            self.vehicle_distance_store.keys()
        )
        expired_v_ids = all_v_ids - current_v_ids
        for v_id in expired_v_ids:
            self.vehicle_prev_position.pop(v_id, None)
            self.vehicle_distance_store.pop(v_id, None)
            self.vehicle_enter_time.pop(v_id, None)
            self.vehicle_ideal_time.pop(v_id, None)
            self.waiting_time_store.pop(v_id, None)
            self.last_waiting_moment.pop(v_id, None)
            self.vehicle_prev_junction.pop(v_id, None)
            self.enter_lane_time.pop(v_id, None)
            self.vehicle_trajectory.pop(v_id, None)
        for j_metrics in self.junction_metrics.values():
            j_metrics["counted_vehicles"] &= current_v_ids
            j_metrics["completed_vehicles"] &= current_v_ids

    def _init_scene_config(self):
        self.weather = self.usr_conf.get("weather", 0)
        self.peak_hour = self.usr_conf.get("rush_hour", 0) == 1
        self.accident_configs = self.usr_conf.get("traffic_accidents", {}).get("custom_configuration", [])
        self.control_configs = self.usr_conf.get("traffic_control", {}).get("custom_configuration", [])
        valid_accident_configs = []
        for idx, acc in enumerate(self.accident_configs):
            if not all(k in acc for k in ["lane_index", "start_time", "end_time"]):
                self.logger.warning(f"Accident rule {idx} is missing required fields, skipping")
                continue
            if not isinstance(acc["lane_index"], int) or acc["lane_index"] < 0:
                self.logger.warning(f"Accident rule {idx} has an invalid lane_index, skipping")
                continue
            if not (isinstance(acc["start_time"], int) and isinstance(acc["end_time"], int)) or acc["start_time"] > acc["end_time"]:
                self.logger.warning(f"Accident rule {idx} has an invalid time range, skipping")
                continue
            valid_accident_configs.append(acc)
        self.accident_configs = valid_accident_configs
        valid_control_configs = []
        for idx, ctrl in enumerate(self.control_configs):
            if not all(k in ctrl for k in ["lane_index", "start_time", "end_time"]):
                self.logger.warning(f"Control rule {idx} is missing required fields, skipping")
                continue
            if not isinstance(ctrl["lane_index"], int) or ctrl["lane_index"] < 0:
                self.logger.warning(f"Control rule {idx} has an invalid lane_index, skipping")
                continue
            if not (isinstance(ctrl["start_time"], int) and isinstance(ctrl["end_time"], int)) or ctrl["start_time"] > ctrl["end_time"]:
                self.logger.warning(f"Control rule {idx} has an invalid time range, skipping")
                continue
            valid_control_configs.append(ctrl)
        self.control_configs = valid_control_configs

    def _update_active_invalid_lanes(self, frame_no):
        self.accident_lanes.clear()
        for acc in self.accident_configs:
            if acc.get("start_time", 0) <= frame_no <= acc.get("end_time", 0):
                lane_idx = acc.get("lane_index", -1)
                if lane_idx in self.lane_dict:
                    self.accident_lanes.add(lane_idx)
                else:
                    self.logger.warning(f"Accident rule lane_index={lane_idx} does not exist, skipping")
        self.control_lanes.clear()
        for ctrl in self.control_configs:
            if ctrl.get("start_time", 0) <= frame_no <= ctrl.get("end_time", 0):
                lane_idx = ctrl.get("lane_index", -1)
                if lane_idx in self.lane_dict:
                    self.control_lanes.add(lane_idx)
                else:
                    self.logger.warning(f"Control rule lane_index={lane_idx} does not exist, skipping")

    def get_region(self, j_id):
        return self.junction_region_map.get(j_id, -1)

    def get_region_capacity(self, region_id):
        return self.region_capacity.get(region_id, 0)

    def get_junction_capacity(self, j_id):
        junction = self.junction_dict.get(j_id, {})
        enter_lanes = junction.get("cached_enter_lanes", [])
        capacity = 0
        for lane_id in enter_lanes:
            lane_type = self.lane_dict.get(lane_id, {}).get("turn_type", 0)
            capacity += 15 if lane_type in [0, 1] else 10
        return capacity if capacity > 0 else 20

    def get_lane_capacity(self, lane_id):
        lane_type = self.lane_dict.get(lane_id, {}).get("turn_type", 0)
        return 15 if lane_type in [0, 1] else 10

    def get_region_avg_queue(self, region_id):
        if region_id not in self.region_dict:
            return 0.0
        total_queue = 0
        count = 0
        for j_id in self.region_dict[region_id]:
            queue = self.get_junction_avg_queue(j_id)
            if queue > 0:
                total_queue += queue
                count += 1
        return total_queue / count if count > 0 else 0.0

    def get_region_congestion(self, region_id):
        total_queue = 0
        total_capacity = self.get_region_capacity(region_id)
        if total_capacity <= 0:
            return 0.0
        for j_id in self.region_dict.get(region_id, []):
            total_queue += self.get_junction_avg_queue(j_id)
        return min(1.0, total_queue / total_capacity)

    def calculate_lane_demand(self):
        self.lane_demand = {}
        for lane_id, vehicle_count in self.lane_volume.items():  # fixed: vehicle_count is an integer (v_count)
            # fixed: base demand is v_count itself, no len() needed
            demand = vehicle_count
            # extra: vehicles close to the stop line (filtered from current_vehicles by lane)
            for v_id, vehicle in self.current_vehicles.items():
                if vehicle.get("lane") == lane_id:  # vehicle is on this lane
                    distance_to_stop = vehicle.get("position_in_lane", {}).get("y", 200)
                    if distance_to_stop < 100:
                        demand += 0.5
            self.lane_demand[lane_id] = demand

    def get_junction_avg_delay(self, j_id):
        metrics = self.junction_metrics.get(j_id, {})
        total_delay = metrics.get("total_delay", 0.0)
        total_vehicles = metrics.get("total_vehicles", 0)
        return total_delay / total_vehicles if total_vehicles > 0 else 0.0

    def get_all_avg_delay(self):
        total_delay = sum(metrics["total_delay"] for metrics in self.junction_metrics.values())
        total_vehicles = sum(metrics["total_vehicles"] for metrics in self.junction_metrics.values())
        return total_delay / total_vehicles if total_vehicles > 0 else 0.0

    def get_junction_avg_waiting(self, j_id):
        metrics = self.junction_metrics.get(j_id, {})
        total_waiting = metrics.get("total_waiting", 0.0)
        total_vehicles = metrics.get("total_vehicles", 0)
        return total_waiting / total_vehicles if total_vehicles > 0 else 0.0

    def get_all_avg_waiting(self):
        total_waiting = sum(metrics["total_waiting"] for metrics in self.junction_metrics.values())
        total_vehicles = sum(metrics["total_vehicles"] for metrics in self.junction_metrics.values())
        return total_waiting / total_vehicles if total_vehicles > 0 else 0.0

    def get_junction_avg_queue(self, j_id):
        metrics = self.junction_metrics.get(j_id, {})
        total_queue = metrics.get("total_queue", 0)
        queue_count = metrics.get("queue_count", 0)
        return total_queue / queue_count if queue_count > 0 else 0.0

    def get_all_avg_queue(self):
        total_queue = sum(metrics["total_queue"] for metrics in self.junction_metrics.values())
        total_count = sum(metrics["queue_count"] for metrics in self.junction_metrics.values())
        return total_queue / total_count if total_count > 0 else 0.0

    def cal_waiting_time(self, frame_time, vehicle):
        v_id = vehicle["v_id"]
        # Note: on_enter_lane is defined in another directory; the original call is kept
        if on_enter_lane(vehicle, self.get_invalid_lanes()):
            distance_to_stop = vehicle["position_in_lane"]["y"]
            if vehicle["speed"] <= 0.1 and distance_to_stop < 50:
                if v_id not in self.last_waiting_moment:
                    self.last_waiting_moment[v_id] = frame_time
                    self.waiting_time_store.setdefault(v_id, 0.0)
                else:
                    waiting_duration = frame_time - self.last_waiting_moment[v_id]
                    self.waiting_time_store[v_id] += waiting_duration
                    self.last_waiting_moment[v_id] = frame_time
            else:
                self.last_waiting_moment.pop(v_id, None)
        else:
            self.waiting_time_store.pop(v_id, None)
            self.last_waiting_moment.pop(v_id, None)

    def cal_travel_distance(self, vehicle):
        v_id = vehicle["v_id"]
        # Note: on_enter_lane is defined in another directory; the original call is kept
        if on_enter_lane(vehicle, self.get_invalid_lanes()):
            if self.vehicle_prev_junction.get(v_id, -1) != -1 and v_id in self.vehicle_distance_store:
                del self.vehicle_distance_store[v_id]
            if v_id not in self.vehicle_prev_position:
                current_pos = vehicle.get("position_in_lane", {})
                if "x" in current_pos and "y" in current_pos and isinstance(current_pos["x"], (int, float)) and isinstance(current_pos["y"], (int, float)):
                    # Core fix 1: populate the distance_to_stop field at init (taken from position_in_lane["y"])
                    self.vehicle_prev_position[v_id] = {
                        "x": current_pos["x"],
                        "y": current_pos["y"],
                        "distance_to_stop": current_pos["y"]  # same semantics as the protocol's y
                    }
                    self.vehicle_distance_store.setdefault(v_id, 0.0)
                else:
                    self.logger.warning(f"Vehicle {v_id} position data invalid (x={current_pos.get('x')}, y={current_pos.get('y')}), cannot initialize history position")
                return
            prev_pos = self.vehicle_prev_position[v_id]
            current_pos = vehicle["position_in_lane"]
            try:
                dx = current_pos["x"] - prev_pos["x"]
                dy = current_pos["y"] - prev_pos["y"]
                euclid_distance = math.hypot(dx, dy)
                # Core fix 2: access with .get and a default (survives a missing field)
                prev_distance_to_stop = prev_pos.get("distance_to_stop", current_pos["y"])
                stop_distance_reduce = prev_distance_to_stop - current_pos["y"]
                self.vehicle_distance_store[v_id] = self.vehicle_distance_store.get(v_id, 0.0) + max(euclid_distance, stop_distance_reduce)
                # Keep distance_to_stop in the stored position up to date (avoids a missing field next time)
                self.vehicle_prev_position[v_id]["distance_to_stop"] = current_pos["y"]
            except Exception as e:
                self.logger.error(f"Failed to compute travel distance for vehicle {v_id}: {str(e)}, details: prev_pos={prev_pos}, current_pos={current_pos}")
        else:
            self.vehicle_prev_position.pop(v_id, None)
            self.vehicle_distance_store.pop(v_id, None)

    def get_all_junction_waiting_time(self, vehicles):
        res = {j_id: 0.0 for j_id in self.junction_dict}
        v_num = {j_id: 0 for j_id in self.junction_dict}
        vehicle_type_weight = {1: 1.0, 2: 1.5, 3: 1.2, 4: 0.8, 5: 0.5}  # matches protocol VehicleType enum values
        for vehicle in vehicles:
            v_id = vehicle["v_id"]
            if (self.vehicle_status.get(v_id, 0) != 0 or
                    vehicle["junction"] != -1 or
                    vehicle["target_junction"] == -1):
                continue
            j_id = vehicle["target_junction"]
            if j_id not in self.junction_dict:
                continue
            v_config_id = vehicle.get("v_config_id", 1)
            v_type = self.vehicle_configs.get(v_config_id, {}).get("v_type", 1)
            weight = vehicle_type_weight.get(v_type, 1.0)
            res[j_id] += self.waiting_time_store.get(v_id, 0.0) * weight
            v_num[j_id] += 1
        for j_id in self.junction_dict:
            if v_num[j_id] > 0:
                res[j_id] /= v_num[j_id]
        return res

    def get_phase_remaining_time(self, signal_id):
        return self.current_phases.get(signal_id, {}).get("remaining_duration", 0)

    def get_invalid_lanes(self):
        return self.accident_lanes.union(self.control_lanes)

    def is_peak_hour(self):
        return self.peak_hour

    def get_weather(self):
        return self.weather

    def get_sorted_junction_ids(self):
        return sorted(self.junction_dict.keys())
```

Based on the above analysis, how should this code be modified?

```python
#!/usr/bin/env python3

import numpy as np
if not hasattr(np, 'float'):
    np.float = np.float32

from isaacgym import gymtorch

from typing import Optional, Dict, Union, Mapping, Tuple, List, Any, Iterable
from dataclasses import dataclass, InitVar, replace
from pathlib import Path
from copy import deepcopy
from gym import spaces
import tempfile
import multiprocessing as mp
import json
import cv2
from functools import partial

import torch as th
import einops
from torch.utils.tensorboard import SummaryWriter

from pkm.models.common import (transfer, map_struct)
# from pkm.models.rl.v4.rppo import (
#     RecurrentPPO as PPO)
from pkm.models.rl.v6.ppo import PPO
from pkm.models.rl.generic_state_encoder import (
    MLPStateEncoder)
from pkm.models.rl.nets import (
    VNet, PiNet, CategoricalPiNet,
    MLPFwdBwdDynLossNet)

# env + general wrappers
# FIXME: ArmEnv _looks_ like a class, but it's
# actually PushEnv + wrapper.
from pkm.env.arm_env import (ArmEnv, ArmEnvConfig, OBS_BOUND_MAP, _identity_bound)
from pkm.env.env.wrap.base import WrapperEnv
from pkm.env.env.wrap.normalize_env import NormalizeEnv
from pkm.env.env.wrap.monitor_env import MonitorEnv
from pkm.env.env.wrap.adaptive_domain_tuner import MultiplyScalarAdaptiveDomainTuner
from pkm.env.env.wrap.nvdr_camera_wrapper import NvdrCameraWrapper
from pkm.env.env.wrap.popdict import PopDict
from pkm.env.env.wrap.mvp_wrapper import MvpWrapper
from pkm.env.env.wrap.normalize_img import NormalizeImg
from pkm.env.env.wrap.tabulate_action import TabulateAction
from pkm.env.env.wrap.nvdr_record_viewer import NvdrRecordViewer
from pkm.env.env.wrap.nvdr_record_episode import NvdrRecordEpisode
from pkm.env.util import set_seed

from pkm.util.config import (ConfigBase, recursive_replace_map)
from pkm.util.hydra_cli import hydra_cli
from pkm.util.path import RunPath, ensure_directory
from pkm.train.ckpt import last_ckpt, step_from_ckpt
from pkm.train.hf_hub import (upload_ckpt, HfConfig, GroupConfig)
from pkm.train.wandb import with_wandb, WandbConfig
from pkm.train.util import (
    assert_committed)

# domain-specific wrappers
from envs.push_env_wrappers import (
    AddGoalThreshFromPushTask)
from envs.cube_env_wrappers import (
    AddObjectMass, AddPhysParams, AddPrevArmWrench, AddPrevAction,
    AddWrenchPenalty, AddObjectEmbedding, AddObjectKeypoint,
    AddObjectFullCloud, AddFingerFullCloud, AddApproxTouchFlag,
    AddTouchCount, AddSuccessAsObs, AddTrackingReward,
    QuatToDCM, QuatTo6D, RelGoal, Phase2Training,
    P2VembObs, ICPEmbObs, PNEmbObs)

# == drawing/debugging wrappers ==
from pkm.env.env.wrap.draw_bbox_kpt import DrawGoalBBoxKeypoint, DrawObjectBBoxKeypoint
from pkm.env.env.wrap.draw_inertia_box import DrawInertiaBox
from pkm.env.env.wrap.draw_clouds import DrawClouds
from pkm.env.env.wrap.draw_patch_attn import DrawPatchAttention
from envs.cube_env_wrappers import (DrawGoalPose, DrawObjPose,
                                    DrawTargetPose, DrawPosBound,
                                    DrawDebugLines)

import nvtx
from icecream import ic


def to_pod(x: np.ndarray) -> List[float]:
    return [float(e) for e in x]


@dataclass
class PolicyConfig(ConfigBase):
    """
    Actor-Critic policy configuration.
    """
    actor: PiNet.Config = PiNet.Config()
    value: VNet.Config = VNet.Config()
    dim_state: InitVar[Optional[int]] = None
    dim_act: InitVar[Optional[int]] = None

    def __post_init__(self, dim_state: Optional[int] = None,
                      dim_act: Optional[int] = None):
        if dim_state is not None:
            self.actor = replace(self.actor, dim_feat=dim_state)
            self.value = replace(self.value, dim_feat=dim_state)
        if dim_act is not None:
            self.actor = replace(self.actor, dim_act=dim_act)


@dataclass
class NetworkConfig(ConfigBase):
    """
    Overall network configuration.
    """
    state: MLPStateEncoder.Config = MLPStateEncoder.Config()
    policy: PolicyConfig = PolicyConfig()
    obs_space: InitVar[Union[int, Dict[str, int], None]] = None
    act_space: InitVar[Optional[int]] = None

    def __post_init__(self, obs_space=None, act_space=None):
        self.state = replace(self.state,
                             obs_space=obs_space,
                             act_space=act_space)
        try:
            if isinstance(act_space, Iterable) and len(act_space) == 1:
                act_space = act_space[0]
            policy = replace(self.policy,
                             dim_state=self.state.state.dim_out,
                             dim_act=act_space)
            self.policy = policy
        except AttributeError:
            pass


@dataclass
class Config(WandbConfig, HfConfig, GroupConfig, ConfigBase):
    # WandbConfig parts
    project: str = 'arm-ppo'
    use_wandb: bool = True
    # HfConfig (huggingface) parts
    hf_repo_id: Optional[str] = 'corn/corn-/arm'
    use_hfhub: bool = True
    # General experiment / logging
    force_commit: bool = False
    description: str = ''
    path: RunPath.Config = RunPath.Config(root='/tmp/pkm/ppo-arm/')
    env: ArmEnvConfig = ArmEnvConfig(which_robot='franka')
    agent: PPO.Config = PPO.Config()
    # State/Policy network configurations
    net: NetworkConfig = NetworkConfig()
    # Loading / continuing from previous runs
    load_ckpt: Optional[str] = None
    transfer_ckpt: Optional[str] = None
    freeze_transferred: bool = True
    global_device: Optional[str] = None

    # VISION CONFIG
    use_camera: bool = False
    camera: NvdrCameraWrapper.Config = NvdrCameraWrapper.Config(
        use_depth=True,
        use_col=True,
        ctx_type='cuda',
        # == D435 config(?) ==
        # aspect=8.0 / 5.0,
        # img_size=(480, 848)
        # z_near for the physical camera is actually pretty large!
        # z_near=0.195
        # Horizontal Field of View 69.4 91.2
        # Vertical Field of View 42.5 65.5
    )
    # Convert img into MVP-pretrained embeddings
    use_mvp: bool = False
    remove_state: bool = False
    remove_robot_state: bool = False
    remove_all_state: bool = False
    # Determines which inputs, even if they remain
    # in the observation dict, are not processed
    # by the state representation network.
    state_net_blocklist: Optional[List[str]] = None
    # FIXME: remove hide_action:
    # legacy config from train_ppo_hand.py
    hide_action: Optional[bool] = True
    add_object_mass: bool = False
    add_object_embedding: bool = False
    add_phys_params: bool = False
    add_keypoint: bool = False
    add_object_full_cloud: bool = False
    add_goal_full_cloud: bool = False
    add_finger_full_cloud: bool = False
    add_prev_wrench: bool = True
    add_prev_action: bool = True
    zero_out_prev_action: bool = False
    add_goal_thresh: bool = False
    add_wrench_penalty: bool = False
    wrench_penalty_coef: float = 1e-4
    add_touch_flag: bool = False
    add_touch_count: bool = False
    min_touch_force: float = 5e-2
    min_touch_speed: float = 1e-3
    add_success: bool = False
    add_tracking_reward: bool = False

    # ==<CURRICULUM>==
    use_tune_init_pos: bool = False
    tune_init_pos_scale: MultiplyScalarAdaptiveDomainTuner.Config = MultiplyScalarAdaptiveDomainTuner.Config(
        step=1.05, easy=0.1, hard=1.0)
    use_tune_goal_radius: bool = False
    tune_goal_radius: MultiplyScalarAdaptiveDomainTuner.Config = MultiplyScalarAdaptiveDomainTuner.Config(
        step=0.95, easy=0.5, hard=0.05)
    use_tune_goal_speed: bool = False
    tune_goal_speed: MultiplyScalarAdaptiveDomainTuner.Config = MultiplyScalarAdaptiveDomainTuner.Config(
        step=0.95, easy=4.0, hard=0.1)
    use_tune_goal_angle: bool = False
    tune_goal_angle: MultiplyScalarAdaptiveDomainTuner.Config = MultiplyScalarAdaptiveDomainTuner.Config(
        step=0.95, easy=1.57, hard=0.05)
    use_tune_pot_gamma: bool = False
    tune_pot_gamma: MultiplyScalarAdaptiveDomainTuner.Config = MultiplyScalarAdaptiveDomainTuner.Config(
        step=0.999, easy=1.00, hard=0.99, step_down=1.001,
        metric='return', target_lower=0.0, target_upper=0.0)
    force_vel: Optional[float] = None
    force_rad: Optional[float] = None
    force_ang: Optional[float] = None
    # ==</CURRICULUM>==

    use_tabulate: bool = False
    tabulate: TabulateAction.Config = TabulateAction.Config(
        num_bin=3)
    use_norm: bool = True
    normalizer: NormalizeEnv.Config = NormalizeEnv.Config()
    # Convert some observations into alternative forms...
    use_dcm: bool = False
    use_rel_goal: bool = False
    use_6d_rel_goal: bool = False
    use_monitor: bool = True
    monitor: MonitorEnv.Config = MonitorEnv.Config()
    # == camera config ==
    use_nvdr_record_episode: bool = False
    nvdr_record_episode: NvdrRecordEpisode.Config = NvdrRecordEpisode.Config()
    use_nvdr_record_viewer: bool = False
    nvdr_record_viewer: NvdrRecordViewer.Config = NvdrRecordViewer.Config(
        img_size=(128, 128))
    normalize_img: bool = True
    img_mean: float = 0.4
    img_std: float = 0.2
    cloud_only: bool = False
    multiple_cameras: bool = False
    camera_eyes: Tuple[Any] = (
        (-0.238, 0.388, 0.694),
        (-0.408, -0.328, 0.706))
    # == "special" training configs ==
    # add auxiliary dynamics network+loss
    add_dyn_aux: bool = False
    # automatic mixed-precision (FP16) training
    use_amp: bool = False
    # DataParallel training across multiple devices
    parallel: Optional[Tuple[int, ...]] = None
    # == periodic validation configs ==
    sample_action: bool = False
    eval_period: int = -1
    eval_step: int = 256
    eval_num_env: int = 16
    eval_record: bool = True
    eval_device: str = 'cuda:0'
    eval_track_per_obj_suc_rate: bool = False
    draw_debug_lines: bool = False
    draw_patch_attn: bool = False
    finalize: bool = False
    is_phase2: bool = False
    phase2: Phase2Training.Config = Phase2Training.Config()
    use_p2v: bool = False
    use_icp_obs: bool = False
    use_pn_obs: bool = False
    p2v: P2VembObs.Config = P2VembObs.Config()
    icp_obs: ICPEmbObs.Config = ICPEmbObs.Config()
    pn_obs: PNEmbObs.Config = PNEmbObs.Config()

    def __post_init__(self):
        self.group = F'{self.machine}-{self.env_name}-{self.model_name}-{self.tag}'
        self.name = F'{self.group}-{self.env.seed:06d}'
        if not self.finalize:
            return
        # WARNING: VERY HAZARDOUS
        use_dr_on_setup = self.env.single_object_scene.use_dr_on_setup | self.is_phase2
        use_dr = self.env.single_object_scene.use_dr | self.is_phase2
        self.env = recursive_replace_map(
            self.env,
            {'franka.compute_wrench': self.add_prev_wrench,
             'franka.add_control_noise': self.is_phase2,
             'single_object_scene.use_dr_on_setup': use_dr_on_setup,
             'single_object_scene.use_dr': use_dr,
             })
        if self.global_device is not None:
            dev_id: int = int(str(self.global_device).split(':')[-1])
            self.env = recursive_replace_map(self.env, {
                'graphics_device_id': (dev_id if self.env.use_viewer else -1),
                'compute_device_id': dev_id,
                'th_device': self.global_device,
            })
            self.agent = recursive_replace_map(self.agent, {
                'device': self.global_device})
        if self.force_vel is not None:
            self.use_tune_goal_speed = False
            self.env.task.max_speed = self.force_vel
        if self.force_rad is not None:
            self.use_tune_goal_radius = False
            self.env.task.goal_radius = self.force_rad
        if self.force_ang is not None:
            self.use_tune_goal_angle = False
            self.env.task.goal_angle = self.force_ang


def setup(cfg: Config):
    # Maybe it's related to jit
    if cfg.global_device is not None:
        th.cuda.set_device(cfg.global_device)
    th.backends.cudnn.benchmark = True
    commit_hash = assert_committed(force_commit=cfg.force_commit)
    path = RunPath(cfg.path)
    print(F'run = {path.dir}')
    return path


class AddTensorboardWriter(WrapperEnv):
    def __init__(self, env):
        super().__init__(env)
        self._writer = None

    def set_writer(self, w):
        self._writer = w

    @property
    def writer(self):
        return self._writer


def load_env(cfg: Config, path, freeze_env: bool = False, **kwds):
    env = ArmEnv(cfg.env)
    env.setup()
    env.gym.prepare_sim(env.sim)
    env.refresh_tensors()
    env.reset()
    env = AddTensorboardWriter(env)

    obs_bound = None
    if cfg.use_norm:
        obs_bound = {}
        # Populate obs_bound with defaults from ArmEnv.
        obs_bound['goal'] = OBS_BOUND_MAP.get(cfg.env.goal_type)
        obs_bound['object_state'] = OBS_BOUND_MAP.get(cfg.env.object_state_type)
        obs_bound['hand_state'] = OBS_BOUND_MAP.get(cfg.env.hand_state_type)
        obs_bound['robot_state'] = OBS_BOUND_MAP.get(cfg.env.robot_state_type)
        if cfg.normalizer.norm.stats is not None:
            obs_bound.update(deepcopy(cfg.normalizer.norm.stats))
        print(obs_bound)

    def __update_obs_bound(key, value, obs_bound, overwrite: bool = True):
        if not cfg.use_norm:
            return
        if value is None:
            obs_bound.pop(key, None)
        if key in obs_bound:
            if overwrite:
                print(F'\t WARN: key = {key} already in obs_bound !')
            else:
                raise ValueError(F'key = {key} already in obs_bound !')
        obs_bound[key] = value
    update_obs_bound = partial(__update_obs_bound, obs_bound=obs_bound)

    if cfg.env.task.use_pose_goal:
        if cfg.add_goal_full_cloud:
            update_obs_bound('goal', OBS_BOUND_MAP.get('cloud'))
        else:
            update_obs_bound('goal', OBS_BOUND_MAP.get(cfg.env.goal_type))

    # Crude check for mutual exclusion.
    # Determines what type of privileged "state" information
    # the policy will receive, as observation.
    assert (np.count_nonzero(
        [cfg.remove_state, cfg.remove_robot_state, cfg.remove_all_state]) <= 1)
    if cfg.remove_state:
        env = PopDict(env, ['object_state'])
        update_obs_bound('object_state', None)
    elif cfg.remove_robot_state:
        env = PopDict(env, ['hand_state'])
        update_obs_bound('hand_state', None)
    elif cfg.remove_all_state:
        env = PopDict(env, ['hand_state', 'object_state'])
        update_obs_bound('hand_state', None)
        update_obs_bound('object_state', None)

    if cfg.add_object_mass:
        env = AddObjectMass(env, 'object_mass')
        update_obs_bound('object_mass', OBS_BOUND_MAP.get('mass'))
    if cfg.add_phys_params:
        env = AddPhysParams(env, 'phys_params')
        update_obs_bound('phys_params', OBS_BOUND_MAP.get('phys_params'))
    if cfg.add_object_embedding:
        env = AddObjectEmbedding(env, 'object_embedding')
        update_obs_bound('object_embedding', OBS_BOUND_MAP.get('embedding'))
    if cfg.add_keypoint:
        env = AddObjectKeypoint(env, 'object_keypoint')
        update_obs_bound('object_keypoint', OBS_BOUND_MAP.get('keypoint'))
    if cfg.add_object_full_cloud:
        # mutually exclusive w.r.t. use_cloud,
        # i.e. the partial point cloud coming from the camera.
        # assert (cfg.camera.use_cloud is False)
        goal_key = None
        if cfg.add_goal_full_cloud:
            goal_key = 'goal'
        env = AddObjectFullCloud(env, 'cloud', goal_key=goal_key)
        update_obs_bound('cloud', OBS_BOUND_MAP.get('cloud'))
        if goal_key is not None:
            update_obs_bound(goal_key, OBS_BOUND_MAP.get('cloud'))
    if cfg.add_finger_full_cloud:
        env = AddFingerFullCloud(env, 'finger_cloud')
        update_obs_bound('finger_cloud', OBS_BOUND_MAP.get('cloud'))
    if cfg.add_goal_thresh:
        env = AddGoalThreshFromPushTask(env, key='goal_thresh', dim=3)
        update_obs_bound('goal_thresh', _identity_bound(3))
    if cfg.add_prev_wrench:
        env = AddPrevArmWrench(env, 'previous_wrench')
        update_obs_bound('previous_wrench', OBS_BOUND_MAP.get('wrench'))
    if cfg.add_prev_action:
        env = AddPrevAction(env, 'previous_action',
                            zero_out=cfg.zero_out_prev_action)
        update_obs_bound('previous_action', _identity_bound(
            env.observation_space['previous_action'].shape))
    if cfg.add_wrench_penalty:
        env = AddWrenchPenalty(env, cfg.wrench_penalty_coef,
                               key='env/wrench_cost')
    if cfg.add_touch_flag:
        env = AddApproxTouchFlag(env, key='touch',
                                 min_force=cfg.min_touch_force,
                                 min_speed=cfg.min_touch_speed)
    if cfg.add_touch_count:
        assert (cfg.add_touch_flag)
        env = AddTouchCount(env, key='touch_count')
        update_obs_bound('touch_count', _identity_bound(
            env.observation_space['touch_count'].shape))
    if cfg.add_success:
        env = AddSuccessAsObs(env, key='success')
        update_obs_bound('success', _identity_bound(()))

    if cfg.use_camera:
        prev_space_keys = deepcopy(list(env.observation_space.keys()))
        env = NvdrCameraWrapper(env, cfg.camera)
        for k in env.observation_space.keys():
            if k in prev_space_keys:
                continue
            obs_shape = env.observation_space[k].shape
            # if k in cfg.normalizer.obs_shape:
            #     obs_shape = cfg.normalizer.obs_shape[k]
            print(k, obs_shape)
            if 'cloud' in k:
                update_obs_bound(k, OBS_BOUND_MAP.get('cloud'))
            else:
                update_obs_bound(k, _identity_bound(obs_shape[-1:]))
        if cfg.multiple_cameras:
            camera = deepcopy(cfg.camera)
            camera = replace(camera, use_label=False)
            for i, eye in enumerate(cfg.camera_eyes):
                cloud_key = f'partial_cloud_{i + 1}'
                new_camera = replace(camera, eye=eye)
                new_camera = replace(new_camera, key_cloud=cloud_key)
                env = NvdrCameraWrapper(env, new_camera)
                update_obs_bound(cloud_key, OBS_BOUND_MAP.get('cloud'))
        if cfg.normalize_img:
            env = NormalizeImg(env, cfg.img_mean, cfg.img_std, key='depth')
            # After normalization, it (should) map to (0.0, 1.0)
            update_obs_bound('depth', (0.0, 1.0))
        if cfg.cloud_only:
            env = PopDict(env, ['depth', 'label'])
            update_obs_bound('depth', None)
            update_obs_bound('label', None)

    if cfg.use_mvp:
        assert (cfg.use_camera)
        env = MvpWrapper(env)
        raise ValueError(
            'MVPWrapper does not currently configure a proper obs space.')
    if cfg.add_tracking_reward:
        env = AddTrackingReward(env, 1e-4)

    # == curriculum ==
    if cfg.use_tune_init_pos:
        def get_init_pos_scale():
            return env.scene._pos_scale

        def set_init_pos_scale(s: float):
            env.scene._pos_scale = s
        env = MultiplyScalarAdaptiveDomainTuner(cfg.tune_init_pos_scale,
                                                env,
                                                get_init_pos_scale,
                                                set_init_pos_scale,
                                                key='env/init_pos_scale')
    if cfg.use_tune_goal_radius:
        def get_goal_rad():
            return env.task.goal_radius

        def set_goal_rad(s: float):
            env.task.goal_radius = s
        env = MultiplyScalarAdaptiveDomainTuner(cfg.tune_goal_radius,
                                                env,
                                                get_goal_rad,
                                                set_goal_rad,
                                                key='env/goal_radius')
    if cfg.use_tune_goal_speed:
        def get_goal_speed():
            return env.task.max_speed

        def set_goal_speed(s: float):
            env.task.max_speed = s
        env = MultiplyScalarAdaptiveDomainTuner(cfg.tune_goal_speed,
                                                env,
                                                get_goal_speed,
                                                set_goal_speed,
                                                key='env/max_speed')
    if cfg.use_tune_goal_angle:
        def get_goal_ang():
            return env.task.goal_angle

        def set_goal_ang(s: float):
            env.task.goal_angle = s
        env = MultiplyScalarAdaptiveDomainTuner(cfg.tune_goal_angle,
                                                env,
                                                get_goal_ang,
                                                set_goal_ang,
                                                key='env/goal_angle')
    if cfg.use_tune_pot_gamma:
        def get_pot_gamma():
            return env.task.gamma

        def set_pot_gamma(s: float):
            env.task.gamma = s
        env = MultiplyScalarAdaptiveDomainTuner(cfg.tune_pot_gamma,
                                                env,
                                                get_pot_gamma,
                                                set_pot_gamma,
                                                key='env/pot_gamma')

    if cfg.use_tabulate:
        env = TabulateAction(cfg.tabulate, env)

    if cfg.use_dcm:
        env = QuatToDCM(env, {
            'goal': 3,
            'hand_state': 3,
            'object_state': 3})
        raise ValueError(
            'DCM (directional cosine matrix) conversions are '
            'currently disabled due to complex integration '
            'with obs_bound.')

    # Use the relative goal between the current object pose
    # and the goal pose, instead of the absolute goal.
    if cfg.use_rel_goal:
        env = RelGoal(env, 'goal', use_6d=cfg.use_6d_rel_goal)
        if cfg.use_6d_rel_goal:
            update_obs_bound('goal', OBS_BOUND_MAP.get('relpose6d'))
        else:
            update_obs_bound('goal', OBS_BOUND_MAP.get('relpose'))

    if cfg.is_phase2:
        env = Phase2Training(cfg.phase2, env)

    # == DRAW, LOG, RECORD ==
    if cfg.draw_debug_lines:
        check_viewer = kwds.pop('check_viewer', True)
        env = DrawDebugLines(DrawDebugLines.Config(
            draw_workspace=kwds.pop('draw_workspace', False),
            draw_wrench_target=kwds.pop('draw_wrench_target', False),
            draw_cube_action=kwds.pop('draw_hand_action', False)),
            env,
            check_viewer=check_viewer)
        # NOTE: blocklist=0 indicates the table;
        # blocklist=2 indicates the robot. Basically,
        # only draw the inertia-box for the object.
        env = DrawInertiaBox(env, blocklist=[0, 2], check_viewer=check_viewer)
        env = DrawObjectBBoxKeypoint(env)
        env = DrawGoalBBoxKeypoint(env)
        env = DrawGoalPose(env, check_viewer=check_viewer)
        env = DrawObjPose(env, check_viewer=check_viewer)
        # Some alternative visualizations are available below;
        # [1] draw the goal as a "pose" frame axes
        # env = DrawTargetPose(env, check_viewer=check_viewer)
        # [2] Draw franka EE boundary
        if cfg.env.franka.track_object:
            env = DrawPosBound(env, check_viewer=check_viewer)
        # [3] Draw input point cloud observations as spheres.
        # Should usually be prevented, so check_viewer=True
        # env = DrawClouds(env, check_viewer=True, stride=8,
        #                  cloud_key='partial_cloud',  # or 'cloud'
        #                  style='ray')

    if cfg.draw_patch_attn:
        class PatchAttentionFromPPV5:
            """
            Retrieve patchified point cloud and attention values
            from PointPatchV5FeatNet.
            """
```
""" def __init__(self): # self.__net = agent.state_net.feature_encoders['cloud'] self.__net = None def register(self, net): self.__net = net def __call__(self, obs): ravel_index = self.__net._patch_index.reshape( *obs['cloud'].shape[:-2], -1, 1) patch = th.take_along_dim( # B, N, D obs['cloud'], # B, (S, P), 1 ravel_index, dim=-2 ).reshape(*self.__net._patch_index.shape, obs['cloud'].shape[-1]) attn = self.__net._patch_attn # ic(attn) # Only include parts that correspond to # point patches # ic('pre',attn.shape) # attn = attn[..., 1:, :] attn = attn[..., :, 1:] # ic('post',attn.shape) # max among heads # attn = attn.max(dim=-2).values # head zero attn = attn[..., 2, :] return (patch, attn) env = DrawPatchAttention(env, PatchAttentionFromPPV5(), dilate=1.2, style='cloud') if cfg.use_nvdr_record_viewer: env = NvdrRecordViewer(cfg.nvdr_record_viewer, env, hide_arm=False) # == MONITOR PERFORMANCE == if cfg.use_monitor: env = MonitorEnv(cfg.monitor, env) # == Normalize environment == # normalization must come after # the monitoring code, since it # overwrites env statistics. if cfg.use_norm: cfg = recursive_replace_map(cfg, {'normalizer.norm.stats': obs_bound}) env = NormalizeEnv(cfg.normalizer, env, path) if cfg.load_ckpt is not None: ckpt_path = Path(cfg.load_ckpt) if ckpt_path.is_file(): # Try to select stats from matching timestep. step = ckpt_path.stem.split('-')[-1] def ckpt_key(ckpt_file): return (step in str(ckpt_file.stem).rsplit('-')[-1]) stat_dir = ckpt_path.parent / '../stat/' else: # Find the latest checkpoint. ckpt_key = step_from_ckpt stat_dir = ckpt_path / '../stat' if stat_dir.is_dir(): stat_ckpt = last_ckpt(stat_dir, key=ckpt_key) print(F'Also loading env stats from {stat_ckpt}') env.load(stat_ckpt, strict=False) # we'll freeze env stats by default, if loading from ckpt. if freeze_env: env.normalizer.eval() else: stat_ckpt = last_ckpt(cfg.load_ckpt + "_stat", key=ckpt_key) print(F'Also loading env stats from {stat_ckpt}') env.load(stat_ckpt, strict=False) if cfg.use_p2v: env = P2VembObs(env, cfg.p2v) env = PopDict(env, ['cloud']) update_obs_bound('cloud', None) if cfg.use_icp_obs: env = ICPEmbObs(env, cfg.icp_obs) env = PopDict(env, ['cloud']) update_obs_bound('cloud', None) if cfg.use_pn_obs: env = PNEmbObs(env, cfg.pn_obs) env = PopDict(env, ['cloud']) update_obs_bound('cloud', None) return cfg, env def load_agent(cfg, env, path, writer): device = env.device ic(cfg) # FIXME: We currently disable MLPStateEncoder from # receiving previous_action implicitly; it has to be # included in the observations explicitly. cfg.net.state.state.dim_act = 0 state_net = MLPStateEncoder.from_config(cfg.net.state) # Create policy/value networks. 
# FIXME: introspection into cfg.dim_out dim_state = state_net.state_aggregator.cfg.dim_out if isinstance(env.action_space, spaces.Discrete): actor_net = CategoricalPiNet(cfg.net.policy.actor).to(device) else: actor_net = PiNet(cfg.net.policy.actor).to(device) value_net = VNet(cfg.net.policy.value).to(device) # Add extra networks (Usually for regularization, # auxiliary losses, or learning extra models) extra_nets = None if cfg.add_dyn_aux: trans_net_cfg = MLPFwdBwdDynLossNet.Config( dim_state=dim_state, dim_act=cfg.net.policy.actor.dim_act, dim_hidden=(128,), ) trans_net = MLPFwdBwdDynLossNet(trans_net_cfg).to(device) extra_nets = {'trans_net': trans_net} agent = PPO( cfg.agent, env, state_net, actor_net, value_net, path, writer, extra_nets=extra_nets ).to(device) if cfg.transfer_ckpt is not None: ckpt = last_ckpt(cfg.transfer_ckpt, key=step_from_ckpt) xfer_dict = th.load(ckpt, map_location='cpu') keys = transfer(agent, xfer_dict['self'], freeze=cfg.freeze_transferred, substrs=[ # 'state_net.feature_encoders', # 'state_net.feature_aggregators' 'state_net' ], # prefix_map={ # 'state_net.feature_encoders.state': # 'state_net.feature_encoders.object_state', # 'state_net.feature_aggregators.state': # 'state_net.feature_aggregators.object_state', # }, verbose=True) print(keys) if cfg.load_ckpt is not None: ckpt: str = last_ckpt(cfg.load_ckpt, key=step_from_ckpt) print(F'Load agent from {ckpt}') agent.load(last_ckpt(cfg.load_ckpt, key=step_from_ckpt), strict=True) return agent def eval_agent_inner(cfg: Config, return_dict): # [1] Silence outputs during validation. import sys import os sys.stdout = open(os.devnull, 'w') sys.stderr = open(os.devnull, 'w') # [2] Import & run validation. from valid_ppo_arm import main return_dict.update(main(cfg)) def eval_agent(cfg: Config, env, agent: PPO): # subprocess.check_output('python3 valid_ppo_hand.py ++run=') manager = mp.Manager() return_dict = manager.dict() with tempfile.TemporaryDirectory() as tmpdir: # Save agent ckpt for validation. ckpt_dir = ensure_directory(F'{tmpdir}/ckpt') stat_dir = ensure_directory(F'{tmpdir}/stat') agent_ckpt = str(ckpt_dir / 'last.ckpt') env_ckpt = str(stat_dir / 'env-last.ckpt') env.save(env_ckpt) agent.save(agent_ckpt) # Override cfg. # FIXME: # Hardcoded target_domain coefs # shold potentially be # tune_goal_speed.hard... # etc. cfg = recursive_replace_map(cfg, { 'load_ckpt': str(ckpt_dir), 'force_vel': 0.1, 'force_rad': 0.05, 'force_ang': 0.1, 'env.num_env': cfg.eval_num_env, 'env.use_viewer': False, 'env.single_object_scene.num_object_types': ( cfg.env.single_object_scene.num_object_types), 'monitor.verbose': False, 'draw_debug_lines': True, 'use_nvdr_record_viewer': cfg.eval_record, 'nvdr_record_viewer.record_dir': F'{tmpdir}/record', 'env.task.mode': 'valid', 'env.single_object_scene.mode': 'valid', 'env.single_object_scene.num_valid_poses': 4, 'global_device': cfg.eval_device, 'eval_track_per_obj_suc_rate': True }) ctx = mp.get_context('spawn') # Run. 
proc = ctx.Process( target=eval_agent_inner, args=(cfg, return_dict), ) proc.start() proc.join() return_dict = dict(return_dict) if 'video' in return_dict: replaced = {} for k, v in return_dict['video'].items(): if isinstance(v, str): video_dir = v assert Path(video_dir).is_dir() filenames = sorted(Path(video_dir).glob('*.png')) rgb_images = [cv2.imread(str(x))[..., ::-1] for x in filenames] vid_array = np.stack(rgb_images, axis=0) v = th.as_tensor(vid_array[None]) v = einops.rearrange(v, 'n t h w c -> n t c h w') replaced[k] = v return_dict['video'] = replaced return return_dict @with_wandb def inner_main(cfg: Config, env, path): """ Basically it's the same as main(), but we commit the config _after_ finalizing. """ commit_hash = assert_committed(force_commit=cfg.force_commit) writer = SummaryWriter(path.tb_train) writer.add_text('meta/commit-hash', str(commit_hash), global_step=0) env.unwrap(target=AddTensorboardWriter).set_writer(writer) agent = load_agent(cfg, env, path, writer) # Enable DataParallel() for subset of modules. if (cfg.parallel is not None) and (th.cuda.device_count() > 1): count: int = th.cuda.device_count() device_ids = list(cfg.parallel) # FIXME: hardcoded DataParallel processing only for # img feature if 'img' in agent.state_net.feature_encoders: agent.state_net.feature_encoders['img'] = th.nn.DataParallel( agent.state_net.feature_encoders['img'], device_ids) ic(agent) def __eval(step: int): logs = eval_agent(cfg, env, agent) log_kwds = {'video': {'fps': 20.0}} # == generic log() == for log_type, log in logs.items(): for tag, value in log.items(): write = getattr(writer, F'add_{log_type}') write(tag, value, global_step=step, **log_kwds.get(log_type, {})) try: th.cuda.empty_cache() with th.cuda.amp.autocast(enabled=cfg.use_amp): for step in agent.learn(name=F'{cfg.name}@{path.dir}'): # Periodically run validation. if (cfg.eval_period > 0) and (step % cfg.eval_period) == 0: th.cuda.empty_cache() __eval(step) finally: # Dump final checkpoints. agent.save(path.ckpt / 'last.ckpt') if hasattr(env, 'save'): env.save(path.stat / 'env-last.ckpt') # Finally, upload the trained model to huggingface model hub. if cfg.use_hfhub and (cfg.hf_repo_id is not None): upload_ckpt( cfg.hf_repo_id, (path.ckpt / 'last.ckpt'), cfg.name) upload_ckpt( cfg.hf_repo_id, (path.stat / 'env-last.ckpt'), cfg.name + '_stat') @hydra_cli( config_path='../../src/pkm/data/cfg/', # config_path='/home/user/mambaforge/envs/genom/lib/python3.8/site-packages/pkm/data/cfg/', config_name='train_rl') def main(cfg: Config): ic.configureOutput(includeContext=True) cfg = recursive_replace_map(cfg, {'finalize': True}) # path, writer = setup(cfg) path = setup(cfg) seed = set_seed(cfg.env.seed) cfg, env = load_env(cfg, path) # Save object names... useful for debugging if True: with open(F'{path.stat}/obj_names.json', 'w') as fp: json.dump(env.scene.cur_names, fp) # Update cfg elements from env. obs_space = map_struct( env.observation_space, lambda src, _: src.shape, base_cls=spaces.Box, dict_cls=(Mapping, spaces.Dict) ) if cfg.state_net_blocklist is not None: for key in cfg.state_net_blocklist: obs_space.pop(key, None) dim_act = ( env.action_space.shape[0] if isinstance( env.action_space, spaces.Box) else env.action_space.n) cfg = replace(cfg, net=replace(cfg.net, obs_space=obs_space, act_space=dim_act, )) return inner_main(cfg, env, path) if __name__ == '__main__': main() Please analyze my code here line by line; do not analyze anything else.
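In outline, this is a Hydra-driven PPO training entry point: `Config.__post_init__` derives the run name and, once `finalize=True`, propagates device and domain-randomization settings into the env config; `setup()` pins the CUDA device and allocates a `RunPath`; `load_env()` builds `ArmEnv` and stacks observation wrappers while keeping the `obs_bound` normalization table in sync; `load_agent()` assembles the state/actor/value networks into `PPO`; `eval_agent()` validates in a spawned subprocess against a temporary checkpoint; and `main()` finalizes the config from the live observation space before handing off to `inner_main()`'s learn loop. Two concrete defects stand out on a line-by-line read:

1. `parallel: Optional[Tuple[int, ...]] = None` is declared twice in `Config`; the second annotation silently shadows the first and should be deleted (the comment above the first occurrence also misspells "network" as "netweork").
2. The closure `__update_obs_bound` inside `load_env()` pops the key when `value is None` but then falls through and re-inserts `None`, so calls such as `update_obs_bound('object_state', None)` leave a `None` bound behind instead of deleting the entry.

```python
# A minimal sketch of the corrected closure (same context inside load_env,
# where `cfg` is captured from the enclosing scope), assuming value=None is
# meant to *remove* a bound -- which is how it is used after each PopDict(...):
def __update_obs_bound(key, value, obs_bound, overwrite: bool = True):
    if not cfg.use_norm:
        return
    if value is None:
        obs_bound.pop(key, None)
        return  # missing in the original; without it, None is re-inserted below
    if key in obs_bound:
        if overwrite:
            print(F'\t WARN: key = {key} already in obs_bound !')
        else:
            raise ValueError(F'key = {key} already in obs_bound !')
    obs_bound[key] = value
```

Everything else is wiring: the long `if cfg.add_*` ladder in `load_env()` is order-sensitive (normalization must stay after `MonitorEnv`, as the comment there notes), and apart from the duplicated field, `Config` is plain dataclass composition.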

import os import json import time import random import threading import datetime import pyautogui import keyboard from tkinter import * from tkinter import ttk, messagebox, filedialog from tkinter.font import Font class ActionRecorder: def __init__(self): self.recording = False self.playing = False self.actions = [] self.current_config = { "random_interval": 100, "loop_times": 1, "script_name": "未命名脚本" } self.configs = {} self.load_configs() # 创建日志目录 if not os.path.exists("logs"): os.makedirs("logs") # 初始化GUI self.init_gui() # 设置快捷键 self.setup_hotkeys() # 鼠标和键盘状态跟踪 self.mouse_state = { "left": False, "right": False, "middle": False } self.keyboard_state = {} # 鼠标和键盘监听器 self.keyboard_hook = None self.mouse_hook = None # 忽略的按键 self.ignored_keys = {'f1', 'f2'} def init_gui(self): self.root = Tk() self.root.title("动作录制器 v3.3") self.root.geometry("800x700") # 设置字体 bold_font = Font(family="微软雅黑", size=10, weight="bold") normal_font = Font(family="微软雅黑", size=9) # 配置区域 config_frame = LabelFrame(self.root, text="配置", padx=10, pady=10, font=bold_font) config_frame.pack(fill="x", padx=10, pady=5) # 脚本名称 Label(config_frame, text="脚本名称:", font=normal_font).grid(row=0, column=0, sticky="e", pady=5) self.script_name_entry = Entry(config_frame, font=normal_font) self.script_name_entry.grid(row=0, column=1, sticky="we", padx=5, pady=5) self.script_name_entry.insert(0, self.current_config["script_name"]) # 随机间隔 Label(config_frame, text="随机间隔(ms):", font=normal_font).grid(row=1, column=0, sticky="e", pady=5) self.interval_entry = Entry(config_frame, font=normal_font) self.interval_entry.grid(row=1, column=1, sticky="we", padx=5, pady=5) self.interval_entry.insert(0, str(self.current_config["random_interval"])) # 循环次数 Label(config_frame, text="循环次数:", font=normal_font).grid(row=2, column=0, sticky="e", pady=5) self.loop_entry = Entry(config_frame, font=normal_font) self.loop_entry.grid(row=2, column=1, sticky="we", padx=5, pady=5) self.loop_entry.insert(0, str(self.current_config["loop_times"])) # 按钮区域 button_frame = Frame(self.root) button_frame.pack(fill="x", padx=10, pady=10) self.record_btn = Button(button_frame, text="开始/停止录制 (F1)", command=self.toggle_recording, font=bold_font, bg="#4CAF50", fg="white") self.record_btn.pack(side="left", padx=5, ipadx=10, ipady=5) self.play_btn = Button(button_frame, text="开始/停止执行 (F2)", command=self.toggle_playing, font=bold_font, bg="#2196F3", fg="white") self.play_btn.pack(side="left", padx=5, ipadx=10, ipady=5) self.save_btn = Button(button_frame, text="保存脚本", command=self.save_script, font=bold_font, bg="#9C27B0", fg="white") self.save_btn.pack(side="right", padx=5, ipadx=10, ipady=5) self.delete_btn = Button(button_frame, text="删除脚本", command=self.delete_script, font=bold_font, bg="#607D8B", fg="white") self.delete_btn.pack(side="right", padx=5, ipadx=10, ipady=5) # 主内容区域 main_frame = Frame(self.root) main_frame.pack(fill="both", expand=True, padx=10, pady=5) # 脚本列表 script_frame = LabelFrame(main_frame, text="保存的脚本", padx=10, pady=10, font=bold_font) script_frame.pack(side="left", fill="y", padx=5, pady=5) self.script_listbox = Listbox(script_frame, font=normal_font, width=25, height=15) self.script_listbox.pack(fill="both", expand=True, padx=5, pady=5) self.script_listbox.bind("<Double-Button-1>", self.load_script) # 日志区域 log_frame = LabelFrame(main_frame, text="操作日志", padx=10, pady=10, font=bold_font) log_frame.pack(side="right", fill="both", expand=True, padx=5, pady=5) self.log_text = Text(log_frame, font=normal_font, wrap=WORD) scrollbar = 
Scrollbar(log_frame, command=self.log_text.yview) self.log_text.configure(yscrollcommand=scrollbar.set) scrollbar.pack(side="right", fill="y") self.log_text.pack(fill="both", expand=True, padx=5, pady=5) # 状态栏 self.status_var = StringVar() self.status_var.set("就绪") status_bar = Label(self.root, textvariable=self.status_var, bd=1, relief=SUNKEN, anchor=W, font=normal_font) status_bar.pack(fill="x", padx=10, pady=5) # 更新脚本列表 self.update_script_list() def setup_hotkeys(self): keyboard.add_hotkey('f1', self.toggle_recording, suppress=True) keyboard.add_hotkey('f2', self.toggle_playing, suppress=True) def toggle_recording(self): if self.recording: self.stop_recording() else: self.start_recording() def toggle_playing(self): if self.playing: self.stop_playing() else: self.start_playing() def log_message(self, message): timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") log_entry = f"[{timestamp}] {message}\n" self.log_text.insert(END, log_entry) self.log_text.see(END) self.status_var.set(message) # 写入日志文件 today = datetime.datetime.now().strftime("%Y-%m-%d") log_file = f"logs/{today}.log" with open(log_file, "a", encoding="utf-8") as f: f.write(log_entry) def save_configs(self): with open("configs.json", "w", encoding="utf-8") as f: json.dump(self.configs, f, ensure_ascii=False, indent=2) def load_configs(self): try: if os.path.exists("configs.json"): with open("configs.json", "r", encoding="utf-8") as f: self.configs = json.load(f) except Exception as e: self.log_message(f"加载配置失败: {str(e)}") self.configs = {} def update_current_config(self): try: self.current_config["script_name"] = self.script_name_entry.get() self.current_config["random_interval"] = int(self.interval_entry.get()) self.current_config["loop_times"] = int(self.loop_entry.get()) return True except ValueError: messagebox.showerror("错误", "请输入有效的数字") return False def save_script(self): if not self.update_current_config(): return if not self.actions: messagebox.showwarning("警告", "没有录制的动作可以保存") return script_name = self.current_config["script_name"] self.configs[script_name] = { "config": self.current_config, "actions": self.actions } self.save_configs() self.update_script_list() self.log_message(f"脚本 '{script_name}' 已保存") def delete_script(self): selection = self.script_listbox.curselection() if not selection: messagebox.showwarning("警告", "请先选择一个脚本") return script_name = self.script_listbox.get(selection[0]) if messagebox.askyesno("确认", f"确定要删除脚本 '{script_name}' 吗?"): if script_name in self.configs: del self.configs[script_name] self.save_configs() self.update_script_list() self.log_message(f"脚本 '{script_name}' 已删除") def update_script_list(self): self.script_listbox.delete(0, END) for script_name in sorted(self.configs.keys()): self.script_listbox.insert(END, script_name) def load_script(self, event=None): selection = self.script_listbox.curselection() if not selection: return script_name = self.script_listbox.get(selection[0]) if script_name in self.configs: script_data = self.configs[script_name] self.current_config = script_data["config"] self.actions = script_data["actions"] # 更新UI self.script_name_entry.delete(0, END) self.script_name_entry.insert(0, self.current_config["script_name"]) self.interval_entry.delete(0, END) self.interval_entry.insert(0, str(self.current_config["random_interval"])) self.loop_entry.delete(0, END) self.loop_entry.insert(0, str(self.current_config["loop_times"])) self.log_message(f"已加载脚本 '{script_name}'") self.log_message(f"共 {len(self.actions)} 个动作") def start_recording(self): if 
self.recording or self.playing: return if not self.update_current_config(): return self.recording = True self.actions = [] self.record_btn.config(bg="#F44336", text="停止录制 (F1)") self.play_btn.config(state=DISABLED) self.save_btn.config(state=DISABLED) self.delete_btn.config(state=DISABLED) # 重置鼠标和键盘状态 self.mouse_state = { "left": False, "right": False, "middle": False } self.keyboard_state = {} # 设置键盘和鼠标钩子 self.keyboard_hook = keyboard.hook(self.on_key_event) self.mouse_hook = keyboard.hook(self.on_mouse_event) self.log_message("开始录制...") self.log_message("请开始您的操作,按F1停止录制") self.log_message("正在录制: 鼠标点击、滚轮和键盘操作") def stop_recording(self): if not self.recording: return self.recording = False self.record_btn.config(bg="#4CAF50", text="开始录制 (F1)") self.play_btn.config(state=NORMAL) self.save_btn.config(state=NORMAL) self.delete_btn.config(state=NORMAL) # 移除钩子 if self.keyboard_hook: keyboard.unhook(self.keyboard_hook) if self.mouse_hook: keyboard.unhook(self.mouse_hook) self.log_message(f"停止录制,共录制了 {len(self.actions)} 个动作") self.save_script() def on_mouse_event(self, event): if not self.recording: return current_pos = pyautogui.position() if event.event_type in ('down', 'up'): button_mapping = { "left": "left", "right": "right", "middle": "middle", "x1": "x1", "x2": "x2" } button = button_mapping.get(event.name, "left") # 默认使用左键 action_type = "mousedown" if event.event_type == 'down' else "mouseup" self.actions.append({ "type": action_type, "button": button, "x": current_pos.x, "y": current_pos.y, "time": time.time() }) action_desc = f"鼠标{'按下' if event.event_type == 'down' else '释放'}: {button}键" self.log_message(f"{action_desc} 位置: ({current_pos.x}, {current_pos.y})") elif event.event_type == 'wheel': direction = "上" if event.delta > 0 else "下" self.actions.append({ "type": "wheel", "delta": event.delta, "x": current_pos.x, "y": current_pos.y, "time": time.time() }) self.log_message(f"滚轮滚动: 方向 {direction} 位置: ({current_pos.x}, {current_pos.y})") def on_key_event(self, event): if not self.recording: return # 忽略F1/F2按键 if event.name.lower() in self.ignored_keys: return if event.event_type == "down": # 只记录第一次按下,不记录重复按下 if event.name not in self.keyboard_state: self.keyboard_state[event.name] = True self.actions.append({ "type": "keydown", "key": event.name, "time": time.time() }) self.log_message(f"按键按下: {event.name}") elif event.event_type == "up": if event.name in self.keyboard_state: del self.keyboard_state[event.name] self.actions.append({ "type": "keyup", "key": event.name, "time": time.time() }) self.log_message(f"按键释放: {event.name}") def start_playing(self): if self.playing or self.recording or not self.actions: return if not self.update_current_config(): return self.playing = True self.play_btn.config(bg="#FF9800", text="停止执行 (F2)") self.record_btn.config(state=DISABLED) self.save_btn.config(state=DISABLED) self.delete_btn.config(state=DISABLED) # 开始执行线程 threading.Thread(target=self.play_actions, daemon=True).start() self.log_message("开始执行脚本...") def stop_playing(self): if not self.playing: return self.playing = False self.play_btn.config(bg="#2196F3", text="开始执行 (F2)") self.record_btn.config(state=NORMAL) self.save_btn.config(state=NORMAL) self.delete_btn.config(state=NORMAL) self.log_message("停止执行脚本") def play_actions(self): loop_times = self.current_config["loop_times"] random_interval = self.current_config["random_interval"] / 1000.0 # 转换为秒 for loop in range(loop_times): if not self.playing: break self.log_message(f"开始第 {loop + 1} 次循环 (共 {loop_times} 次)") prev_time = None for i, action 
in enumerate(self.actions): if not self.playing: break # 计算延迟 if prev_time is not None: delay = action["time"] - prev_time # 添加随机间隔 if random_interval > 0: delay += random.uniform(0, random_interval) time.sleep(max(0, delay)) prev_time = action["time"] # 执行动作 try: if action["type"] == "mousedown": # 确保只使用有效的按钮参数 button = action["button"] if action["button"] in ('left', 'middle', 'right') else 'left' pyautogui.mouseDown(x=action["x"], y=action["y"], button=button) self.log_message( f"执行动作 {i + 1}/{len(self.actions)}: 在({action['x']}, {action['y']})按下 {button}键") elif action["type"] == "mouseup": button = action["button"] if action["button"] in ('left', 'middle', 'right') else 'left' pyautogui.mouseUp(x=action["x"], y=action["y"], button=button) self.log_message( f"执行动作 {i + 1}/{len(self.actions)}: 在({action['x']}, {action['y']})释放 {button}键") elif action["type"] == "wheel": pyautogui.scroll(action["delta"]) self.log_message( f"执行动作 {i + 1}/{len(self.actions)}: 滚轮滚动 {'上' if action['delta'] > 0 else '下'}") elif action["type"] == "keydown": keyboard.press(action["key"]) self.log_message(f"执行动作 {i + 1}/{len(self.actions)}: 按下 {action['key']}键") elif action["type"] == "keyup": keyboard.release(action["key"]) self.log_message(f"执行动作 {i + 1}/{len(self.actions)}: 释放 {action['key']}键") except Exception as e: self.log_message(f"执行动作时出错: {str(e)}") if loop < loop_times - 1 and self.playing: time.sleep(1) # 循环之间的间隔 self.stop_playing() def run(self): self.root.mainloop() if __name__ == "__main__": recorder = ActionRecorder() recorder.run() Please check this keyboard-and-mouse macro recorder for errors and fix them for me!
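The recorder has one bug that silently breaks its core feature: both hooks are registered through the `keyboard` package (`self.mouse_hook = keyboard.hook(self.on_mouse_event)`), but `keyboard` only delivers keyboard events, so `on_mouse_event` never receives clicks or wheel motion and no mouse action is ever recorded. Capturing mouse input needs the companion `mouse` package, whose events also carry `.button` rather than the `.name` attribute the current handler reads. A sketch of the corrected wiring, assuming `pip install mouse`; the helper names below are hypothetical and would be folded into `start_recording()`/`stop_recording()`:

```python
import mouse  # assumption: the `mouse` package, sibling project of `keyboard`

# inside ActionRecorder -- replaces the hook setup in start_recording():
def _install_hooks(self):
    self.keyboard_hook = keyboard.hook(self.on_key_event)
    self.mouse_hook = mouse.hook(self.on_mouse_event)   # was keyboard.hook(...)

# replaces the unhook calls in stop_recording():
def _remove_hooks(self):
    if self.keyboard_hook:
        keyboard.unhook(self.keyboard_hook)
        self.keyboard_hook = None
    if self.mouse_hook:
        mouse.unhook(self.mouse_hook)
        self.mouse_hook = None

def on_mouse_event(self, event):
    if not self.recording:
        return
    pos = pyautogui.position()
    now = time.time()
    if isinstance(event, mouse.ButtonEvent) and event.event_type in ('down', 'up'):
        # ButtonEvent exposes .button ('left'/'right'/'middle'/'x'/'x2'), not .name
        self.actions.append({
            "type": "mousedown" if event.event_type == 'down' else "mouseup",
            "button": event.button, "x": pos.x, "y": pos.y, "time": now})
    elif isinstance(event, mouse.WheelEvent):
        self.actions.append({"type": "wheel", "delta": event.delta,
                             "x": pos.x, "y": pos.y, "time": now})
```

Two smaller problems: `keyboard` hotkeys fire on a listener thread, yet `stop_recording()` ends with `self.save_script()`, which touches Tk widgets and message boxes; Tkinter is not thread-safe, so that call should be marshalled back to the UI thread with `self.root.after(0, self.save_script)`. And during playback, `keyboard.press/release` output would be re-captured if a recording were ever active at the same time, so the mutual `recording`/`playing` guards must stay in place.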

Please explain this code in detail: "import os # os.environ['CUDA_VISIBLE_DEVICES'] = '3' import argparse import pickle from collections import defaultdict from datetime import datetime import numpy as np from mmcv.cnn import get_model_complexity_info from torch.utils.tensorboard import SummaryWriter from visdom import Visdom from configs import cfg, update_config from dataset.multi_label.coco import COCO14 from dataset.augmentation import get_transform from metrics.ml_metrics import get_map_metrics, get_multilabel_metrics from metrics.pedestrian_metrics import get_pedestrian_metrics from models.model_ema import ModelEmaV2 from optim.adamw import AdamW from scheduler.cos_annealing_with_restart import CosineAnnealingLR_with_Restart from scheduler.cosine_lr import CosineLRScheduler from tools.distributed import distribute_bn from tools.vis import tb_visualizer_pedes import torch from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR from torch.utils.data import DataLoader from batch_engine import valid_trainer, batch_trainer from dataset.pedes_attr.pedes import PedesAttr from models.base_block import FeatClassifier from models.model_factory import build_loss, build_classifier, build_backbone from tools.function import get_model_log_path, get_reload_weight, seperate_weight_decay from tools.utils import time_str, save_ckpt, ReDirectSTD, set_seed, str2bool from models.backbone import swin_transformer, resnet, bninception, vit from models.backbone.tresnet import tresnet from losses import bceloss, scaledbceloss from models import base_block # torch.backends.cudnn.benchmark = True # torch.autograd.set_detect_anomaly(True) torch.autograd.set_detect_anomaly(True) def main(cfg, args): set_seed(605) exp_dir = os.path.join('exp_result', cfg.DATASET.NAME) model_dir, log_dir = get_model_log_path(exp_dir, cfg.NAME) stdout_file = os.path.join(log_dir, f'stdout_{time_str()}.txt') save_model_path = os.path.join(model_dir, f'ckpt_max_{time_str()}.pth') visdom = None if cfg.VIS.VISDOM: visdom = Visdom(env=f'{cfg.DATASET.NAME}_' + cfg.NAME, port=8401) assert visdom.check_connection() writer = None if cfg.VIS.TENSORBOARD.ENABLE: current_time = datetime.now().strftime('%b%d_%H-%M-%S') writer_dir = os.path.join(exp_dir, cfg.NAME, 'runs', current_time) writer = SummaryWriter(log_dir=writer_dir) if cfg.REDIRECTOR: print('redirector stdout') ReDirectSTD(stdout_file, 'stdout', False) """ the reason for args usage is CfgNode is immutable """ if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 else: args.distributed = None args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() print(f'use GPU{args.device} for training') print(args.world_size, args.rank) if args.local_rank == 0: print(cfg) train_tsfm, valid_tsfm = get_transform(cfg) if args.local_rank == 0: print(train_tsfm) if cfg.DATASET.TYPE == 'pedes': train_set = PedesAttr(cfg=cfg, split=cfg.DATASET.TRAIN_SPLIT, transform=train_tsfm, target_transform=cfg.DATASET.TARGETTRANSFORM) valid_set = PedesAttr(cfg=cfg, split=cfg.DATASET.VAL_SPLIT, transform=valid_tsfm, target_transform=cfg.DATASET.TARGETTRANSFORM) elif cfg.DATASET.TYPE == 'multi_label': train_set = COCO14(cfg=cfg, split=cfg.DATASET.TRAIN_SPLIT, transform=train_tsfm, target_transform=cfg.DATASET.TARGETTRANSFORM) valid_set =
COCO14(cfg=cfg, split=cfg.DATASET.VAL_SPLIT, transform=valid_tsfm, target_transform=cfg.DATASET.TARGETTRANSFORM) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) else: train_sampler = None train_loader = DataLoader( dataset=train_set, batch_size=cfg.TRAIN.BATCH_SIZE, sampler=train_sampler, shuffle=train_sampler is None, num_workers=4, pin_memory=True, drop_last=True, ) valid_loader = DataLoader( dataset=valid_set, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, ) if args.local_rank == 0: print('-' * 60) print(f'{cfg.DATASET.NAME} attr_num : {train_set.attr_num}, eval_attr_num : {train_set.eval_attr_num} ' f'{cfg.DATASET.TRAIN_SPLIT} set: {len(train_loader.dataset)}, ' f'{cfg.DATASET.TEST_SPLIT} set: {len(valid_loader.dataset)}, ' ) labels = train_set.label label_ratio = labels.mean(0) if cfg.LOSS.SAMPLE_WEIGHT else None backbone, c_output = build_backbone(cfg.BACKBONE.TYPE, cfg.BACKBONE.MULTISCALE) classifier = build_classifier(cfg.CLASSIFIER.NAME)( nattr=train_set.attr_num, c_in=c_output, bn=cfg.CLASSIFIER.BN, pool=cfg.CLASSIFIER.POOLING, scale =cfg.CLASSIFIER.SCALE ) model = FeatClassifier(backbone, classifier, bn_wd=cfg.TRAIN.BN_WD) if args.local_rank == 0: print(f"backbone: {cfg.BACKBONE.TYPE}, classifier: {cfg.CLASSIFIER.NAME}") print(f"model_name: {cfg.NAME}") # flops, params = get_model_complexity_info(model, (3, 256, 128), print_per_layer_stat=True) # print('{:<30} {:<8}'.format('Computational complexity: ', flops)) # print('{:<30} {:<8}'.format('Number of parameters: ', params)) model = model.cuda() if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) else: model = torch.nn.DataParallel(model) model_ema = None if cfg.TRAIN.EMA.ENABLE: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEmaV2( model, decay=cfg.TRAIN.EMA.DECAY, device='cpu' if cfg.TRAIN.EMA.FORCE_CPU else None) if cfg.RELOAD.TYPE: model = get_reload_weight(model_dir, model, pth=cfg.RELOAD.PTH) loss_weight = cfg.LOSS.LOSS_WEIGHT criterion = build_loss(cfg.LOSS.TYPE)( sample_weight=label_ratio, scale=cfg.CLASSIFIER.SCALE, size_sum=cfg.LOSS.SIZESUM, tb_writer=writer) criterion = criterion.cuda() if cfg.TRAIN.BN_WD: param_groups = [{'params': model.module.finetune_params(), 'lr': cfg.TRAIN.LR_SCHEDULER.LR_FT, 'weight_decay': cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY}, {'params': model.module.fresh_params(), 'lr': cfg.TRAIN.LR_SCHEDULER.LR_NEW, 'weight_decay': cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY}] else: # bn parameters are not applied with weight decay ft_params = seperate_weight_decay( model.module.finetune_params(), lr=cfg.TRAIN.LR_SCHEDULER.LR_FT, weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY) fresh_params = seperate_weight_decay( model.module.fresh_params(), lr=cfg.TRAIN.LR_SCHEDULER.LR_NEW, weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY) param_groups = ft_params + fresh_params if cfg.TRAIN.OPTIMIZER.TYPE.lower() == 'sgd': optimizer = torch.optim.SGD(param_groups, momentum=cfg.TRAIN.OPTIMIZER.MOMENTUM) elif cfg.TRAIN.OPTIMIZER.TYPE.lower() == 'adam': optimizer = torch.optim.Adam(param_groups) elif cfg.TRAIN.OPTIMIZER.TYPE.lower() == 'adamw': optimizer = AdamW(param_groups) else: assert None, f'{cfg.TRAIN.OPTIMIZER.TYPE} is not implemented' if cfg.TRAIN.LR_SCHEDULER.TYPE == 'plateau': lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=4) if cfg.CLASSIFIER.BN: 
assert False, 'BN can not compatible with ReduceLROnPlateau' elif cfg.TRAIN.LR_SCHEDULER.TYPE == 'multistep': lr_scheduler = MultiStepLR(optimizer, milestones=cfg.TRAIN.LR_SCHEDULER.LR_STEP, gamma=0.1) elif cfg.TRAIN.LR_SCHEDULER.TYPE == 'annealing_cosine': lr_scheduler = CosineAnnealingLR_with_Restart( optimizer, T_max=(cfg.TRAIN.MAX_EPOCH + 5) * len(train_loader), T_mult=1, eta_min=cfg.TRAIN.LR_SCHEDULER.LR_NEW * 0.001 ) elif cfg.TRAIN.LR_SCHEDULER.TYPE == 'warmup_cosine': lr_scheduler = CosineLRScheduler( optimizer, t_initial=cfg.TRAIN.MAX_EPOCH, lr_min=1e-5, # cosine lr 最终回落的位置 warmup_lr_init=1e-4, warmup_t=cfg.TRAIN.MAX_EPOCH * cfg.TRAIN.LR_SCHEDULER.WMUP_COEF, ) else: assert False, f'{cfg.LR_SCHEDULER.TYPE} has not been achieved yet' best_metric, epoch = trainer(cfg, args, epoch=cfg.TRAIN.MAX_EPOCH, model=model, model_ema=model_ema, train_loader=train_loader, valid_loader=valid_loader, criterion=criterion, optimizer=optimizer, lr_scheduler=lr_scheduler, path=save_model_path, loss_w=loss_weight, viz=visdom, tb_writer=writer) if args.local_rank == 0: print(f'{cfg.NAME}, best_metrc : {best_metric} in epoch{epoch}') def trainer(cfg, args, epoch, model, model_ema, train_loader, valid_loader, criterion, optimizer, lr_scheduler, path, loss_w, viz, tb_writer): maximum = float(-np.inf) best_epoch = 0 result_list = defaultdict() result_path = path result_path = result_path.replace('ckpt_max', 'metric') result_path = result_path.replace('pth', 'pkl') for e in range(epoch): if args.distributed: train_loader.sampler.set_epoch(epoch) lr = optimizer.param_groups[1]['lr'] train_loss, train_gt, train_probs, train_imgs, train_logits, train_loss_mtr = batch_trainer( cfg, args=args, epoch=e, model=model, model_ema=model_ema, train_loader=train_loader, criterion=criterion, optimizer=optimizer, loss_w=loss_w, scheduler=lr_scheduler if cfg.TRAIN.LR_SCHEDULER.TYPE == 'annealing_cosine' else None, ) if args.distributed: if args.local_rank == 0: print("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') if model_ema is not None and not cfg.TRAIN.EMA.FORCE_CPU: if args.local_rank == 0: print('using model_ema to validate') if args.distributed: distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') valid_loss, valid_gt, valid_probs, valid_imgs, valid_logits, valid_loss_mtr = valid_trainer( cfg, args=args, epoch=e, model=model_ema.module, valid_loader=valid_loader, criterion=criterion, loss_w=loss_w ) else: valid_loss, valid_gt, valid_probs, valid_imgs, valid_logits, valid_loss_mtr = valid_trainer( cfg, args=args, epoch=e, model=model, valid_loader=valid_loader, criterion=criterion, loss_w=loss_w ) if cfg.TRAIN.LR_SCHEDULER.TYPE == 'plateau': lr_scheduler.step(metrics=valid_loss) elif cfg.TRAIN.LR_SCHEDULER.TYPE == 'warmup_cosine': lr_scheduler.step(epoch=e + 1) elif cfg.TRAIN.LR_SCHEDULER.TYPE == 'multistep': lr_scheduler.step() if cfg.METRIC.TYPE == 'pedestrian': train_result = get_pedestrian_metrics(train_gt, train_probs, index=None, cfg=cfg) valid_result = get_pedestrian_metrics(valid_gt, valid_probs, index=None, cfg=cfg) if args.local_rank == 0: print(f'Evaluation on train set, train losses {train_loss}\n', 'ma: {:.4f}, label_f1: {:.4f}, pos_recall: {:.4f} , neg_recall: {:.4f} \n'.format( train_result.ma, np.mean(train_result.label_f1), np.mean(train_result.label_pos_recall), np.mean(train_result.label_neg_recall)), 'Acc: {:.4f}, Prec: {:.4f}, Rec: {:.4f}, F1: {:.4f}'.format( train_result.instance_acc, train_result.instance_prec, 
train_result.instance_recall, train_result.instance_f1)) print(f'Evaluation on test set, valid losses {valid_loss}\n', 'ma: {:.4f}, label_f1: {:.4f}, pos_recall: {:.4f} , neg_recall: {:.4f} \n'.format( valid_result.ma, np.mean(valid_result.label_f1), np.mean(valid_result.label_pos_recall), np.mean(valid_result.label_neg_recall)), 'Acc: {:.4f}, Prec: {:.4f}, Rec: {:.4f}, F1: {:.4f}'.format( valid_result.instance_acc, valid_result.instance_prec, valid_result.instance_recall, valid_result.instance_f1)) print(f'{time_str()}') print('-' * 60) if args.local_rank == 0: tb_visualizer_pedes(tb_writer, lr, e, train_loss, valid_loss, train_result, valid_result, train_gt, valid_gt, train_loss_mtr, valid_loss_mtr, model, train_loader.dataset.attr_id) cur_metric = valid_result.ma if cur_metric > maximum: maximum = cur_metric best_epoch = e save_ckpt(model, path, e, maximum) result_list[e] = { 'train_result': train_result, # 'train_map': train_map, 'valid_result': valid_result, # 'valid_map': valid_map, 'train_gt': train_gt, 'train_probs': train_probs, 'valid_gt': valid_gt, 'valid_probs': valid_probs, 'train_imgs': train_imgs, 'valid_imgs': valid_imgs } elif cfg.METRIC.TYPE == 'multi_label': train_metric = get_multilabel_metrics(train_gt, train_probs) valid_metric = get_multilabel_metrics(valid_gt, valid_probs) if args.local_rank == 0: print( 'Train Performance : mAP: {:.4f}, OP: {:.4f}, OR: {:.4f}, OF1: {:.4f} CP: {:.4f}, CR: {:.4f}, ' 'CF1: {:.4f}'.format(train_metric.map, train_metric.OP, train_metric.OR, train_metric.OF1, train_metric.CP, train_metric.CR, train_metric.CF1)) print( 'Test Performance : mAP: {:.4f}, OP: {:.4f}, OR: {:.4f}, OF1: {:.4f} CP: {:.4f}, CR: {:.4f}, ' 'CF1: {:.4f}'.format(valid_metric.map, valid_metric.OP, valid_metric.OR, valid_metric.OF1, valid_metric.CP, valid_metric.CR, valid_metric.CF1)) print(f'{time_str()}') print('-' * 60) tb_writer.add_scalars('train/lr', {'lr': lr}, e) tb_writer.add_scalars('train/losses', {'train': train_loss, 'test': valid_loss}, e) tb_writer.add_scalars('train/perf', {'mAP': train_metric.map, 'OP': train_metric.OP, 'OR': train_metric.OR, 'OF1': train_metric.OF1, 'CP': train_metric.CP, 'CR': train_metric.CR, 'CF1': train_metric.CF1}, e) tb_writer.add_scalars('test/perf', {'mAP': valid_metric.map, 'OP': valid_metric.OP, 'OR': valid_metric.OR, 'OF1': valid_metric.OF1, 'CP': valid_metric.CP, 'CR': valid_metric.CR, 'CF1': valid_metric.CF1}, e) cur_metric = valid_metric.map if cur_metric > maximum: maximum = cur_metric best_epoch = e save_ckpt(model, path, e, maximum) result_list[e] = { 'train_result': train_metric, 'valid_result': valid_metric, 'train_gt': train_gt, 'train_probs': train_probs, 'valid_gt': valid_gt, 'valid_probs': valid_probs } else: assert False, f'{cfg.METRIC.TYPE} is unavailable' with open(result_path, 'wb') as f: pickle.dump(result_list, f) return maximum, best_epoch def argument_parser(): parser = argparse.ArgumentParser(description="attribute recognition", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--cfg", help="decide which cfg to use", type=str, default="./configs/pedes_baseline/pa100k.yaml", ) parser.add_argument("--debug", type=str2bool, default="true") parser.add_argument('--local_rank', help='node rank for distributed training', default=0, type=int) parser.add_argument('--dist_bn', type=str, default='', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') args = parser.parse_args() return args if __name__ == '__main__': args = argument_parser() 
update_config(cfg, args) main(cfg, args)"
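In outline, the quoted script is a pedestrian-attribute / multi-label recognition trainer: `main()` seeds everything, optionally wires up Visdom and TensorBoard, builds (optionally distributed) data loaders over `PedesAttr` or `COCO14`, composes `FeatClassifier(backbone, classifier)`, constructs the loss, optimizer, and one of four LR schedulers from the config, and then `trainer()` runs the epoch loop: train via `batch_trainer`, validate via `valid_trainer` (through the EMA copy when enabled), step the scheduler according to its type, log metrics, checkpoint whenever the monitored metric (`ma` for pedestrian, mAP for multi-label) improves, and finally pickle the per-epoch results. Two details deserve a fix while reading:

- under DDP, `train_loader.sampler.set_epoch(epoch)` passes the *total* epoch count instead of the loop index, so every epoch reuses the same shuffle;
- `result_list = defaultdict()` has no default factory and is effectively a plain dict, and `lr = optimizer.param_groups[1]['lr']` hard-codes the existence of a second parameter group.

```python
# Hedged sketch of the loop-head fixes in trainer(); only the changed lines
# are shown, with all names taken from the source:
result_list = {}                    # defaultdict() without a factory is just dict()
for e in range(epoch):
    if args.distributed:
        train_loader.sampler.set_epoch(e)   # was set_epoch(epoch): same shuffle each epoch
    # the last param group always exists, unlike a hard-coded index 1
    lr = optimizer.param_groups[-1]['lr']
```

Note also that `torch.autograd.set_detect_anomaly(True)` is left on at module import; it is a debugging aid that slows training substantially and is best re-commented for real runs.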

class FeatureProcess: def __init__(self, logger, usr_conf=None): self.logger = logger self.usr_conf = usr_conf or {} # 接收并保存配置 self.reset() def reset(self): # 1. 静态道路信息 self.junction_dict = {} self.edge_dict = {} self.lane_dict = {} self.vehicle_configs = {} self.l_id_to_index = {} self.lane_to_junction = {} self.phase_lane_mapping = {} # 结构 # 2. 动态帧信息 self.vehicle_status = {} self.lane_volume = {} self.lane_congestion = {} self.current_phases = {} self.lane_demand = {} # :车道需求缓存 # 3. 车辆历史轨迹数据 self.vehicle_prev_junction = {} self.vehicle_prev_position = {} self.vehicle_distance_store = {} self.last_waiting_moment = {} self.waiting_time_store = {} self.enter_lane_time = {} self.vehicle_enter_time = {} # 车辆进入系统时间 self.vehicle_ideal_time = {} # 理想行驶时间 self.current_vehicles = {} # 缓存当前帧车辆 self.vehicle_trajectory = {} # 车辆轨迹缓存 # 4. 场景状态 self.peak_hour = False self.weather = 0 self.accident_lanes = set() self.control_lanes = set() self.accident_configs = [] self.control_configs = [] # 5. 路口级指标累计 self.junction_metrics = {} self.enter_lane_ids = set() # 区域信息 self.region_dict = {} # 区域ID -> 包含的路口ID列表 self.region_capacity = {} # 区域ID -> 区域容量 self.junction_region_map = {} # 路口ID -> 区域ID def init_road_info(self, start_info): """初始化道路信息(:解析信号-相位-车道映射)""" junctions, signals, edges = ( start_info["junctions"], start_info["signals"], # 信号数据(含相位-车道关系) start_info["edges"], ) lane_configs, vehicle_configs = ( start_info["lane_configs"], start_info["vehicle_configs"], ) # 1:先建立车道-路口映射 self.enter_lane_ids = set() # 重置进口车道集合 for junction in junctions: j_id = junction["j_id"] self.junction_dict[j_id] = junction self.l_id_to_index[j_id] = {} # 收集所有进口车道ID for approaching_edges in junction["enter_lanes_on_directions"]: for lane in approaching_edges["lanes"]: self.enter_lane_ids.add(lane) self.lane_to_junction[lane] = j_id # 初始化指标累计 self.junction_metrics[j_id] = { "total_delay": 0.0, "total_vehicles": 0, "total_waiting": 0.0, "total_queue": 0, "queue_count": 0, "counted_vehicles": set(), "completed_vehicles": set() } # 路口-车道索引映射 index = 0 for approaching_edges in junction["enter_lanes_on_directions"]: for lane in approaching_edges["lanes"]: self.l_id_to_index[j_id][lane] = index index += 1 # 2:解析信号-相位-车道映射 self.phase_lane_mapping.clear() for signal in signals: s_id = signal["s_id"] # 信号ID(对应路口的信号灯) self.phase_lane_mapping[s_id] = {} # 初始化该信号的相位映射 # 遍历信号下的所有相位 for phase in signal.get("phases", []): phase_id = phase.get("phase_id") # 相位ID phase_lanes = phase.get("lanes", []) # 该相位对应的车道列表 if phase_id is not None and phase_lanes: self.phase_lane_mapping[s_id][phase_id] = phase_lanes else: self.logger.warning(f"信号 {s_id} 的相位 {phase_id} 缺失车道数据,跳过") # 3:全局道路信息 for edge in edges: self.edge_dict[edge["e_id"]] = edge for lane in lane_configs: l_id = lane["l_id"] self.lane_dict[l_id] = lane self.lane_volume[l_id] = [] self.lane_congestion[l_id] = 0 # 4:车辆配置 for cfg in vehicle_configs: self.vehicle_configs[cfg["v_config_id"]] = { "v_type": cfg["v_type"], "max_speed": cfg["max_speed"] } # 5:区域信息初始化 regions = start_info.get("regions", []) for region in regions: r_id = region["r_id"] self.region_dict[r_id] = region["junction_ids"] self.region_capacity[r_id] = 0 for j_id in region["junction_ids"]: self.junction_region_map[j_id] = r_id # 累加区域容量(进口车道数×10) junction = self.junction_dict.get(j_id, {}) enter_lanes = [lane for dirs in junction.get("enter_lanes_on_directions", []) for lane in dirs.get("lanes", [])] self.region_capacity[r_id] += len(enter_lanes) * 10 def get_region(self, j_id): """获取路口所属区域ID""" return 
self.junction_region_map.get(j_id, -1) def get_region_capacity(self, region_id): """获取区域容量""" return self.region_capacity.get(region_id, 0) def get_junction_capacity(self, j_id): """获取路口容量(基于进口车道数)""" junction = self.junction_dict.get(j_id, {}) if not junction: return 20 enter_lanes = [] for dirs in junction.get("enter_lanes_on_directions", []): enter_lanes.extend(dirs.get("lanes", [])) # 根据车道类型分配不同容量 capacity = 0 for lane_id in enter_lanes: lane_type = self.lane_dict.get(lane_id, {}).get("turn_type", 0) # 左转和直行车道容量更高 if lane_type in [0, 1]: # 直行/左转 capacity += 15 else: # 右转 capacity += 10 return capacity def get_lane_capacity(self, lane_id): """获取单个车道容量""" lane_type = self.lane_dict.get(lane_id, {}).get("turn_type", 0) return 15 if lane_type in [0, 1] else 10 def get_region_avg_queue(self, region_id): """获取区域平均排队长度""" if region_id not in self.region_dict: return 0.0 total_queue = 0 count = 0 for j_id in self.region_dict[region_id]: queue = self.get_junction_avg_queue(j_id) if queue > 0: total_queue += queue count += 1 return total_queue / count if count > 0 else 0.0 def get_region_congestion(self, region_id): """获取区域拥堵指数(0-1)""" total_queue = 0 total_capacity = self.get_region_capacity(region_id) if total_capacity <= 0: return 0.0 for j_id in self.region_dict.get(region_id, []): total_queue += self.get_junction_avg_queue(j_id) return min(1.0, total_queue / total_capacity) # 车道需求计算 def calculate_lane_demand(self): """计算车道需求(排队长度+接近车辆)""" self.lane_demand = {} for lane_id, vehicles in self.lane_volume.items(): # 基础需求 = 当前排队车辆数 demand = len(vehicles) # 增加接近车辆权重(距离<100米) for v_id in vehicles: vehicle = self.current_vehicles.get(v_id) if vehicle and vehicle.get("position_in_lane", {}).get("y", 0) < 100: demand += 0.5 # 接近车辆增加部分需求 self.lane_demand[lane_id] = demand def update_traffic_info(self, obs, extra_info): """更新动态信息(优化延误计算和内存管理)""" frame_state = obs["framestate"] frame_no = frame_state["frame_no"] frame_time_ms = frame_state["frame_time"] vehicles = frame_state["vehicles"] phases = frame_state["phases"] lanes = frame_state["lanes"] frame_time = frame_time_ms / 1000.0 current_v_ids = {v["v_id"] for v in vehicles} self.current_vehicles = {v["v_id"]: v for v in vehicles} # 缓存当前车辆 self.calculate_lane_demand() # 首次帧初始化 if frame_no == 1: game_info = extra_info["gameinfo"] self.init_road_info(game_info) self._init_scene_config() if frame_no % 100 == 0: self.logger.info( f"【帧 {frame_no} 动态场景状态】" f"天气: {self.get_weather()}, 高峰期: {self.is_peak_hour()}, " f"生效事故车道: {self.accident_lanes}, 生效管制车道: {self.control_lanes}, " f"已加载相位-车道映射的信号数: {len(self.phase_lane_mapping)}" # 日志:验证映射是否加载 ) # 更新无效车道 self._update_active_invalid_lanes(frame_no) # 1. 车道-车辆映射 self.lane_volume = {l_id: [] for l_id in self.lane_dict} lane_vehicle_map = defaultdict(list) for vehicle in vehicles: lane_id = vehicle["lane"] lane_vehicle_map[lane_id].append(vehicle["v_id"]) for lane_id, v_ids in lane_vehicle_map.items(): if lane_id in self.lane_volume: self.lane_volume[lane_id] = v_ids # 更新拥堵等级 for lane in lanes: l_id = lane["lane_id"] if l_id in self.lane_congestion: self.lane_congestion[l_id] = lane["congestion"] # 2. 更新相位 self.current_phases.clear() for phase in phases: self.current_phases[phase["s_id"]] = { "remaining_duration": phase["remaining_duration"], "phase_id": phase["phase_id"] } # 3. 
逐车更新 for vehicle in vehicles: v_id = vehicle["v_id"] lane_id = vehicle["lane"] v_config_id = vehicle["v_config_id"] target_j_id = vehicle["target_junction"] current_j_id = vehicle["junction"] # 车辆状态标记 if lane_id in self.accident_lanes or lane_id in self.control_lanes: self.vehicle_status[v_id] = 1 else: self.vehicle_status[v_id] = vehicle["v_status"] # 初始化车辆进入时间 if v_id not in self.vehicle_enter_time: self.vehicle_enter_time[v_id] = frame_time max_speed = self.vehicle_configs.get(v_config_id, {}).get("max_speed", 10) # 计算理想时间 lane_length = self.lane_dict.get(lane_id, {}).get("length", 100) ideal_time = lane_length / max_speed self.vehicle_ideal_time[v_id] = ideal_time # 初始化历史路口 if v_id not in self.vehicle_prev_junction: self.vehicle_prev_junction[v_id] = current_j_id # 处理正常车辆 if self.vehicle_status[v_id] == 0: # 更新进入进口道时间 if (self.vehicle_prev_junction[v_id] == -1 and on_enter_lane(vehicle) and v_id not in self.enter_lane_time): self.enter_lane_time[v_id] = frame_time elif (self.vehicle_prev_junction[v_id] != current_j_id and self.vehicle_prev_junction[v_id] != -1 and on_enter_lane(vehicle)): self.enter_lane_time[v_id] = frame_time # 计算等待时间/行驶距离 self.cal_waiting_time(frame_time, vehicle) self.cal_travel_distance(vehicle) # 判断车辆是否已通过目标路口 if (target_j_id != -1 and target_j_id in self.junction_metrics and v_id in self.junction_metrics[target_j_id]["completed_vehicles"]): continue # 标记"已通过目标路口"的车辆 if (target_j_id != -1 and target_j_id in self.junction_metrics and current_j_id == target_j_id and not on_enter_lane(vehicle)): self.junction_metrics[target_j_id]["completed_vehicles"].add(v_id) continue # 计算实时延误 if target_j_id != -1 and target_j_id in self.junction_metrics: actual_time = frame_time - self.vehicle_enter_time[v_id] current_dist = self.vehicle_distance_store.get(v_id, 0.0) max_speed = self.vehicle_configs.get(v_config_id, {}).get("max_speed", 10) # 使用更准确的车道长度 lane_length = self.lane_dict.get(lane_id, {}).get("length", 100) remaining_dist = max(0, lane_length - vehicle["position_in_lane"]["y"]) total_ideal_dist = current_dist + remaining_dist updated_ideal_time = total_ideal_dist / max_speed delay = max(0.0, actual_time - updated_ideal_time) # 累计延误 if v_id not in self.junction_metrics[target_j_id]["counted_vehicles"]: self.junction_metrics[target_j_id]["total_delay"] += delay self.junction_metrics[target_j_id]["total_vehicles"] += 1 self.junction_metrics[target_j_id]["counted_vehicles"].add(v_id) # 累计等待时间 waiting_time = self.waiting_time_store.get(v_id, 0.0) self.junction_metrics[target_j_id]["total_waiting"] += waiting_time # 更新历史路口 self.vehicle_prev_junction[v_id] = current_j_id # 4. 
        # 累计排队长度
        for j_id in self.junction_dict:
            current_queue = self.calculate_junction_queue(j_id, vehicles, current_v_ids)
            self.junction_metrics[j_id]["total_queue"] += current_queue
            self.junction_metrics[j_id]["queue_count"] += 1

        # 清理已离开系统的车辆历史数据
        self._clean_expired_vehicle_data(current_v_ids)

        # 初始化车辆轨迹(若未初始化)
        if not hasattr(self, 'vehicle_trajectory'):
            self.vehicle_trajectory = {}
        for v_id, vehicle in self.current_vehicles.items():
            if v_id not in self.vehicle_trajectory:
                self.vehicle_trajectory[v_id] = {
                    "entered_junction": False,
                    "distance": 0.0,
                    "last_position": vehicle.get("position_in_lane", {})
                }
            else:
                # 更新行驶距离
                last_pos = self.vehicle_trajectory[v_id]["last_position"]
                curr_pos = vehicle.get("position_in_lane", {})
                if "x" in last_pos and "y" in last_pos and "x" in curr_pos and "y" in curr_pos:
                    dx = curr_pos["x"] - last_pos["x"]
                    dy = curr_pos["y"] - last_pos["y"]
                    self.vehicle_trajectory[v_id]["distance"] += math.hypot(dx, dy)
                self.vehicle_trajectory[v_id]["last_position"] = curr_pos
            # 标记是否进入路口
            if not self.vehicle_trajectory[v_id]["entered_junction"]:
                if vehicle.get("junction", -1) != -1:
                    self.vehicle_trajectory[v_id]["entered_junction"] = True

    # 路口车辆计数方法
    def get_junction_volume(self, j_id):
        """获取路口当前交通量(进口车道车辆总数)"""
        count = 0
        junction = self.junction_dict.get(j_id, {})
        for dirs in junction.get("enter_lanes_on_directions", []):
            for lane_id in dirs.get("lanes", []):
                count += len(self.lane_volume.get(lane_id, []))
        return count

    def calculate_junction_queue(self, j_id, vehicles, current_v_ids):
        """计算排队长度(使用更准确的车道判断)"""
        junction = self.junction_dict[j_id]
        invalid_lanes = self.get_invalid_lanes()
        enter_lanes = [lane for dirs in junction["enter_lanes_on_directions"] for lane in dirs["lanes"]]
        valid_lanes = [l for l in enter_lanes if l not in invalid_lanes]
        queue = 0
        vehicle_dict = {v["v_id"]: v for v in vehicles}
        for lane_id in valid_lanes:
            for v_id in self.lane_volume.get(lane_id, []):
                if v_id not in current_v_ids or self.vehicle_status.get(v_id, 0) != 0:
                    continue
                vehicle = vehicle_dict.get(v_id, {})
                if vehicle.get("speed", 0) <= 1.0:
                    queue += 1
        return queue

    def _clean_expired_vehicle_data(self, current_v_ids):
        # 计算所有需要删除的车辆ID
        all_v_ids = set(
            self.vehicle_prev_position.keys()
            | self.vehicle_enter_time.keys()
            | self.waiting_time_store.keys()
            | self.vehicle_prev_junction.keys()
            | self.vehicle_distance_store.keys()
        )
        expired_v_ids = all_v_ids - current_v_ids
        # 批量删除
        for v_id in expired_v_ids:
            self.vehicle_prev_position.pop(v_id, None)
            self.vehicle_distance_store.pop(v_id, None)
            self.vehicle_enter_time.pop(v_id, None)
            self.vehicle_ideal_time.pop(v_id, None)
            self.waiting_time_store.pop(v_id, None)
            self.last_waiting_moment.pop(v_id, None)
            self.vehicle_prev_junction.pop(v_id, None)
            self.enter_lane_time.pop(v_id, None)
            self.vehicle_trajectory.pop(v_id, None)
        # 清理路口指标中的过期车辆
        for j_metrics in self.junction_metrics.values():
            j_metrics["counted_vehicles"] &= current_v_ids
            j_metrics["completed_vehicles"] &= current_v_ids

    # 路口级指标查询接口
    def get_junction_avg_delay(self, j_id):
        metrics = self.junction_metrics.get(j_id, {})
        total_delay = metrics.get("total_delay", 0.0)
        total_vehicles = metrics.get("total_vehicles", 0)
        return total_delay / total_vehicles if total_vehicles > 0 else 0.0

    def get_all_avg_delay(self):
        total_delay = sum(metrics["total_delay"] for metrics in self.junction_metrics.values())
        total_vehicles = sum(metrics["total_vehicles"] for metrics in self.junction_metrics.values())
        return total_delay / total_vehicles if total_vehicles > 0 else 0.0

    def get_junction_avg_waiting(self, j_id):
        metrics = self.junction_metrics.get(j_id, {})
        total_waiting = metrics.get("total_waiting", 0.0)
        total_vehicles = metrics.get("total_vehicles", 0)
        return total_waiting / total_vehicles if total_vehicles > 0 else 0.0

    def get_all_avg_waiting(self):
        total_waiting = sum(metrics["total_waiting"] for metrics in self.junction_metrics.values())
        total_vehicles = sum(metrics["total_vehicles"] for metrics in self.junction_metrics.values())
        return total_waiting / total_vehicles if total_vehicles > 0 else 0.0

    def get_junction_avg_queue(self, j_id):
        metrics = self.junction_metrics.get(j_id, {})
        total_queue = metrics.get("total_queue", 0)
        queue_count = metrics.get("queue_count", 0)
        return total_queue / queue_count if queue_count > 0 else 0.0

    def get_all_avg_queue(self):
        total_queue = sum(metrics["total_queue"] for metrics in self.junction_metrics.values())
        total_count = sum(metrics["queue_count"] for metrics in self.junction_metrics.values())
        return total_queue / total_count if total_count > 0 else 0.0

    def cal_waiting_time(self, frame_time, vehicle):
        v_id = vehicle["v_id"]
        if on_enter_lane(vehicle, self.get_invalid_lanes()):
            distance_to_stop = vehicle["position_in_lane"]["y"]
            if vehicle["speed"] <= 0.1 and distance_to_stop < 50:
                if v_id not in self.last_waiting_moment:
                    self.last_waiting_moment[v_id] = frame_time
                    self.waiting_time_store.setdefault(v_id, 0.0)
                else:
                    waiting_duration = frame_time - self.last_waiting_moment[v_id]
                    self.waiting_time_store[v_id] += waiting_duration
                    self.last_waiting_moment[v_id] = frame_time
            else:
                if v_id in self.last_waiting_moment:
                    del self.last_waiting_moment[v_id]
        else:
            if v_id in self.waiting_time_store:
                del self.waiting_time_store[v_id]
            if v_id in self.last_waiting_moment:
                del self.last_waiting_moment[v_id]

    def cal_travel_distance(self, vehicle):
        v_id = vehicle["v_id"]
        if on_enter_lane(vehicle, self.get_invalid_lanes()):
            # 若车辆从其他路口进入当前进口道,重置距离记录
            if self.vehicle_prev_junction[v_id] != -1 and v_id in self.vehicle_distance_store:
                del self.vehicle_distance_store[v_id]
                self.vehicle_distance_store[v_id] = 0.0
            # 初始化历史位置(若未初始化)
            if v_id not in self.vehicle_prev_position:
                current_pos = vehicle.get("position_in_lane", {})
                # 补充位置数据校验(避免无效数据)
                if "x" in current_pos and "y" in current_pos and isinstance(current_pos["x"], (int, float)) and isinstance(current_pos["y"], (int, float)):
                    self.vehicle_prev_position[v_id] = {
                        "x": current_pos["x"],
                        "y": current_pos["y"],
                        "distance_to_stop": current_pos["y"]
                    }
                    # 若距离未初始化,补充初始化
                    if v_id not in self.vehicle_distance_store:
                        self.vehicle_distance_store[v_id] = 0.0
                else:
                    self.logger.warning(f"车辆 {v_id} 位置数据无效,无法初始化历史位置")
                return
            prev_pos = self.vehicle_prev_position[v_id]
            current_pos = vehicle["position_in_lane"]
            try:
                dx = current_pos["x"] - prev_pos["x"]
                dy = current_pos["y"] - prev_pos["y"]
                euclid_distance = math.hypot(dx, dy)
                stop_distance_reduce = prev_pos["distance_to_stop"] - current_pos["y"]
                # 确保距离记录存在(双重保险)
                if v_id not in self.vehicle_distance_store:
                    self.vehicle_distance_store[v_id] = 0.0
                self.vehicle_distance_store[v_id] += max(euclid_distance, stop_distance_reduce)
            except Exception as e:
                self.logger.error(
                    f"计算行驶距离失败 (v_id={v_id}): "
                    f"异常类型={type(e).__name__}, 异常信息={str(e)}, "
                    f"position_in_lane={current_pos}, prev_pos={prev_pos}"
                )
        else:
            # 车辆离开进口道,清理历史数据
            if v_id in self.vehicle_prev_position:
                del self.vehicle_prev_position[v_id]
            if v_id in self.vehicle_distance_store:
                del self.vehicle_distance_store[v_id]

    def get_all_junction_waiting_time(self, vehicles):
        res = {j_id: 0.0 for j_id in self.junction_dict}
        v_num = {j_id: 0 for j_id in self.junction_dict}
        vehicle_type_weight = {1: 1.0, 2: 1.5, 3: 1.2, 4: 0.8, 5: 0.5}
        for vehicle in vehicles:
            v_id = vehicle["v_id"]
            if (self.vehicle_status.get(v_id, 0) != 0
                    or vehicle["junction"] != -1
                    or vehicle["target_junction"] == -1):
                continue
            j_id = vehicle["target_junction"]
            v_type = self.vehicle_configs[vehicle["v_config_id"]]["v_type"]
            weight = vehicle_type_weight.get(v_type, 1.0)
            res[j_id] += self.waiting_time_store.get(v_id, 0.0) * weight
            v_num[j_id] += 1
        for j_id in self.junction_dict:
            if v_num[j_id] > 0:
                res[j_id] /= v_num[j_id]
        return res

    def get_phase_remaining_time(self, signal_id):
        return self.current_phases.get(signal_id, {}).get("remaining_duration", 0)

    def _init_scene_config(self):
        # 1. 读取配置
        self.weather = self.usr_conf.get("weather", 0)
        self.peak_hour = self.usr_conf.get("rush_hour", 0) == 1
        self.accident_configs = self.usr_conf.get("traffic_accidents", {}).get("custom_configuration", [])
        self.control_configs = self.usr_conf.get("traffic_control", {}).get("custom_configuration", [])
        # 2. 配置校验
        valid_accident_configs = []
        for idx, acc in enumerate(self.accident_configs):
            # 校验关键字段存在性+类型+逻辑
            if not all(k in acc for k in ["lane_index", "start_time", "end_time"]):
                self.logger.warning(f"事故规则 {idx} 缺失关键字段(lane_index/start_time/end_time),跳过")
                continue
            if not isinstance(acc["lane_index"], int) or acc["lane_index"] < 0:
                self.logger.warning(f"事故规则 {idx} lane_index={acc['lane_index']} 无效(需非负整数),跳过")
                continue
            if not (isinstance(acc["start_time"], int) and isinstance(acc["end_time"], int)) or acc["start_time"] > acc["end_time"]:
                self.logger.warning(f"事故规则 {idx} 时间范围无效(start={acc['start_time']}, end={acc['end_time']}),跳过")
                continue
            valid_accident_configs.append(acc)
        self.accident_configs = valid_accident_configs
        # 同理校验管制规则
        valid_control_configs = []
        for idx, ctrl in enumerate(self.control_configs):
            if not all(k in ctrl for k in ["lane_index", "start_time", "end_time"]):
                self.logger.warning(f"管制规则 {idx} 缺失关键字段,跳过")
                continue
            if not isinstance(ctrl["lane_index"], int) or ctrl["lane_index"] < 0:
                self.logger.warning(f"管制规则 {idx} lane_index={ctrl['lane_index']} 无效,跳过")
                continue
            if not (isinstance(ctrl["start_time"], int) and isinstance(ctrl["end_time"], int)) or ctrl["start_time"] > ctrl["end_time"]:
                self.logger.warning(f"管制规则 {idx} 时间范围无效,跳过")
                continue
            valid_control_configs.append(ctrl)
        self.control_configs = valid_control_configs
        self.logger.info("【FeatureProcess 配置解析结果】")
        self.logger.info(f"- 天气: {self.weather} (0=晴/1=雨/2=雪/3=雾)")
        self.logger.info(f"- 高峰期: {'是' if self.peak_hour else '否'}")
        self.logger.info(f"- 有效事故规则数: {len(self.accident_configs)}(原始: {len(self.usr_conf.get('traffic_accidents', {}).get('custom_configuration', []))})")
        self.logger.info(f"- 有效管制规则数: {len(self.control_configs)}(原始: {len(self.usr_conf.get('traffic_control', {}).get('custom_configuration', []))})")

    def _update_active_invalid_lanes(self, frame_no):
        self.accident_lanes.clear()
        for acc in self.accident_configs:
            if acc.get("start_time", 0) <= frame_no <= acc.get("end_time", 0):
                lane_idx = acc.get("lane_index", -1)
                if lane_idx in self.lane_dict:
                    self.accident_lanes.add(lane_idx)
                else:
                    self.logger.warning(f"事故规则 lane_index={lane_idx} 不存在,跳过(当前有效车道:{list(self.lane_dict.keys())})")
        self.control_lanes.clear()
        for ctrl in self.control_configs:
            if ctrl.get("start_time", 0) <= frame_no <= ctrl.get("end_time", 0):
                lane_idx = ctrl.get("lane_index", -1)
                if lane_idx in self.lane_dict:
                    self.control_lanes.add(lane_idx)
                else:
                    self.logger.warning(f"管制规则 lane_index={lane_idx} 不存在,跳过(当前有效车道:{list(self.lane_dict.keys())})")

    def get_invalid_lanes(self):
        return self.accident_lanes.union(self.control_lanes)

    def is_peak_hour(self):
        return self.peak_hour

    def get_weather(self):
        return self.weather

    def get_sorted_junction_ids(self):
        return sorted(self.junction_dict.keys())

根据你的分析将此特征处理全部修改后给我
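补充一个最小的调用示意,便于理解上述查询接口的配合方式(假设 `fp` 为已完成初始化、且每帧都已更新过指标的 FeatureProcess 实例,`frame_no` 为当前帧号;函数名 `report_metrics` 为示意用途虚构):

```python
def report_metrics(fp, frame_no):
    """打印各路口及全局指标(fp 为假设已初始化的 FeatureProcess 实例)"""
    fp._update_active_invalid_lanes(frame_no)   # 先按当前帧刷新事故/管制车道
    invalid = fp.get_invalid_lanes()            # 当前不可用车道集合
    print("不可用车道:", invalid)
    for j_id in fp.get_sorted_junction_ids():
        print(j_id,
              fp.get_junction_volume(j_id),     # 当前交通量
              fp.get_junction_avg_delay(j_id),  # 平均延误
              fp.get_junction_avg_queue(j_id))  # 平均排队长度
    # 全网平均指标
    print("全局:", fp.get_all_avg_delay(), fp.get_all_avg_waiting(), fp.get_all_avg_queue())
```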

# agent/autonomous_agent.py
import os  # 补充缺失的 os 导入(max_workers 的默认值计算需要)
from concurrent.futures import ThreadPoolExecutor

from core.config import get_config
from core.subsystem_registry import SubsystemRegistry
from core.circuit_breaker import CircuitBreakerRegistry

# 获取配置
system_config = get_config()


class AutonomousAgent:
    def __init__(self):
        # 使用配置系统获取 MAX_WORKERS
        max_workers = system_config.get("MAX_WORKERS", max(1, os.cpu_count() * 2))
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        # 获取子系统注册表
        self.registry = SubsystemRegistry()

        # 初始化熔断器
        self.circuit_breaker = CircuitBreakerRegistry.get_breaker(
            "autonomous_agent",
            failure_threshold=system_config.get("AGENT_FAILURE_THRESHOLD", 5),
            recovery_timeout=system_config.get("AGENT_RECOVERY_TIMEOUT", 60)
        )

    def initialize(self):
        """初始化智能体"""
        # 确保所有子系统已初始化
        if not self.registry.initialized:
            self.registry.initialize_all()

        # 获取关键子系统
        self.hardware = self.registry.get("hardware_manager")
        self.scheduler = self.registry.get("life_scheduler")
        self.memory = self.registry.get("memory_manager")

        # 初始化完成
        print("AutonomousAgent 初始化完成")

    def run(self):
        """运行智能体主循环"""
        try:
            while True:
                # 使用熔断器保护关键操作
                self.circuit_breaker.call(self._run_cycle)
        except KeyboardInterrupt:
            print("智能体运行终止")
        except Exception as e:
            print(f"智能体运行错误: {str(e)}")
            # 错误处理逻辑...

    def _run_cycle(self):
        """执行单个运行周期"""
        # 获取下一个任务
        task = self.scheduler.get_next_task()

        # 在线程池中执行任务
        future = self.executor.submit(self.hardware.execute, task)
        future.add_done_callback(self._task_completed)

    def _task_completed(self, future):
        """任务完成回调"""
        try:
            result = future.result()
            # 处理任务结果...
        except Exception as e:
            print(f"任务执行失败: {str(e)}")
            # 错误处理逻辑...


# 子系统注册
@SubsystemRegistry.subsystem("autonomous_agent", dependencies=["hardware_manager", "life_scheduler"])
class RegisteredAutonomousAgent(AutonomousAgent):
    """注册为子系统的智能体"""
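上面 `_run_cycle` 中 submit + add_done_callback 的组合是 concurrent.futures 的标准库用法,可以用一个不依赖任何自定义模块的最小示例单独验证其行为(`work`、`on_done` 均为示意用途虚构):

```python
from concurrent.futures import ThreadPoolExecutor

def work(x):
    return x * 2

def on_done(future):
    try:
        print("结果:", future.result())  # 任务正常完成时取回返回值
    except Exception as e:
        print("任务执行失败:", e)        # 任务中的异常会在 result() 处重新抛出

with ThreadPoolExecutor(max_workers=2) as pool:
    pool.submit(work, 21).add_done_callback(on_done)  # 回调在任务完成后触发
```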

# agent/autonomous_agent.py
import os
import sys
import time
import json
import logging
import traceback
import threading
import platform
import psutil
from pathlib import Path
from typing import Any, Dict, Optional, List, Callable, Tuple
from concurrent.futures import ThreadPoolExecutor, Future, TimeoutError

# 确保项目根目录在 sys.path 中(必须在导入 core 模块之前执行)
BASE_DIR = Path(__file__).resolve().parent.parent.parent  # 指向 E:\AI_System
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))

# 原错误导入:from core.config import system_config
# 修改为正确导入:core.config 暴露的是 CoreConfig 类,需手动创建配置实例
# (注意:不能再重复 from core.config import system_config,否则会覆盖这里的实例)
from core.config import CoreConfig
system_config = CoreConfig()  # 创建配置实例

# 导入核心模块
from core.exceptions import DependencyError, SubsystemFailure, ConfigurationError
from core.metrics import MetricsCollector
from core.circuit_breaker import CircuitBreaker
from core.subsystem_registry import SubsystemRegistry

# 全局线程池
executor = ThreadPoolExecutor(max_workers=system_config.MAX_WORKERS)


class AutonomousAgent:
    def __init__(self):
        """重构后的自主智能体核心类,负责协调所有子系统"""
        self.logger = self._setup_logger()
        self.logger.info("🚀 初始化自主智能体核心模块...")
        self._running = False
        self._background_thread = None

        # 初始化状态跟踪
        self.initialization_steps = []
        self._last_env_check = 0
        self._initialization_time = time.time()
        self.metrics = MetricsCollector()

        # 熔断器管理器
        self.circuit_breakers: Dict[str, CircuitBreaker] = {}

        # 子系统注册表
        self.subsystem_registry = SubsystemRegistry()

        # 环境管理器(外部设置)
        self.environment = None

        # 确保必要目录存在
        self._ensure_directories_exist()

        try:
            # 初始化步骤
            self._record_step("验证配置")
            self._validate_configuration()

            self._record_step("加载环境变量")
            self._load_environment()

            self._record_step("验证环境")
            self.verify_environment()

            self._record_step("初始化核心组件")
            self._initialize_core_components()

            self._record_step("初始化子系统")
            self._initialize_subsystems()

            self.logger.info(f"✅ 自主智能体初始化完成 (耗时: {time.time() - self._initialization_time:.2f}秒)")
            self.logger.info(f"初始化步骤: {', '.join(self.initialization_steps)}")
        except Exception as e:
            self.logger.exception(f"❌ 智能体初始化失败: {str(e)}")
            self.logger.error(f"堆栈跟踪:\n{traceback.format_exc()}")
            raise RuntimeError(f"智能体初始化失败: {str(e)}") from e

    def _setup_logger(self) -> logging.Logger:
        """配置日志记录器"""
        logger = logging.getLogger('AutonomousAgent')
        logger.setLevel(system_config.LOG_LEVEL)

        # 创建控制台处理器
        console_handler = logging.StreamHandler()
        console_handler.setLevel(system_config.LOG_LEVEL)

        # 创建文件处理器
        log_file = Path(system_config.LOG_DIR) / 'autonomous_agent.log'
        log_file.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setLevel(system_config.LOG_LEVEL)

        # 创建格式化器
        formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s] %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        console_handler.setFormatter(formatter)
        file_handler.setFormatter(formatter)

        # 添加处理器
        logger.addHandler(console_handler)
        logger.addHandler(file_handler)
        logger.propagate = False
        return logger

    def _ensure_directories_exist(self):
        """确保所需目录存在"""
        required_dirs = [
            system_config.LOG_DIR,
            system_config.CONFIG_DIR,
            system_config.MODEL_CACHE_DIR
        ]
        for dir_path in required_dirs:
            try:
                if not isinstance(dir_path, Path):
                    dir_path = Path(dir_path)
                if not dir_path.exists():
                    dir_path.mkdir(parents=True, exist_ok=True)
                    self.logger.info(f"创建目录: {dir_path}")
            except Exception as e:
                self.logger.error(f"创建目录失败 {dir_path}: {str(e)}")

    def _validate_configuration(self):
        """验证关键配置项"""
        required_configs = [
            'LOG_DIR', 'CONFIG_DIR', 'MODEL_CACHE_DIR',
            'MAX_WORKERS', 'AGENT_RESPONSE_TIMEOUT'
        ]
        missing = []
        for config_key in required_configs:
            if not hasattr(system_config, config_key):
                missing.append(config_key)
        if missing:
            raise ConfigurationError(f"缺失关键配置项: {', '.join(missing)}")
        # 检查配置值有效性
        if system_config.MAX_WORKERS <= 0:
            raise ConfigurationError(f"无效的MAX_WORKERS值: {system_config.MAX_WORKERS}")

    def _record_step(self, step_name: str):
        """记录初始化步骤"""
        self.initialization_steps.append(step_name)
        self.logger.info(f"⏳ 步骤 {len(self.initialization_steps)}: {step_name}")

    def _load_environment(self):
        """加载环境变量"""
        env_path = system_config.CONFIG_DIR / ".env"
        if not env_path.exists():
            self.logger.warning(f"⚠️ 环境变量文件不存在: {env_path}")
            return
        try:
            from dotenv import load_dotenv
            load_dotenv(env_path)
            self.logger.info(f"✅ 已加载环境变量文件: {env_path}")
        except ImportError:
            self.logger.warning("dotenv包未安装,跳过环境变量加载。请安装: pip install python-dotenv")
        except Exception as e:
            self.logger.error(f"加载环境变量失败: {str(e)}")

    def set_environment(self, env_manager):
        """设置环境管理器引用"""
        self.environment = env_manager
        self.logger.info("✅ 已连接环境管理器")
        # 注册环境监控任务
        if self.environment:
            self.subsystem_registry.register_task(
                "环境监控",
                self._monitor_environment,
                interval=system_config.get('ENVIRONMENT_MONITOR_INTERVAL', 5.0)
            )

    def start(self):
        """启动智能体后台任务"""
        if not self._running:
            self._start_background_tasks()
            self.logger.info("🏁 智能体后台任务已启动")
        else:
            self.logger.warning("智能体已在运行中")

    def _start_background_tasks(self):
        """启动后台任务线程"""
        if self._running:
            return
        self._running = True
        self._background_thread = threading.Thread(
            target=self._background_task_loop,
            daemon=True,
            name="AutonomousAgentBackgroundTasks"
        )
        self._background_thread.start()
        self.logger.info("✅ 后台任务线程已启动")

    def _background_task_loop(self):
        """后台任务循环"""
        self.logger.info("🔄 后台任务循环启动")
        while self._running:
            try:
                start_time = time.time()
                # 执行注册的周期性任务
                self.subsystem_registry.run_periodic_tasks()
                # 动态调整睡眠时间
                task_time = time.time() - start_time
                sleep_time = max(0.1, system_config.AGENT_TASK_INTERVAL - task_time)
                time.sleep(sleep_time)
            except Exception as e:
                self.logger.error(f"后台任务错误: {str(e)}")
                self.metrics.record_error('background_task')
                time.sleep(30)  # 错误后等待更长时间

    def verify_environment(self):
        """验证运行环境是否满足要求"""
        # 检查必需模块
        required_modules = [
            'os', 'sys', 'logging', 'flask', 'werkzeug',
            'numpy', 'transformers', 'torch', 'psutil'
        ]
        # 检查必需包
        required_packages = [
            ('dotenv', 'python-dotenv'),
            ('flask_socketio', 'flask-socketio')
        ]
        missing_modules = []
        for mod in required_modules:
            try:
                __import__(mod)
            except ImportError:
                missing_modules.append(mod)
        missing_packages = []
        for import_name, pkg_name in required_packages:
            try:
                __import__(import_name)
            except ImportError:
                missing_packages.append(pkg_name)
        # 处理缺失项
        errors = []
        if missing_modules:
            errors.append(f"缺失Python模块: {', '.join(missing_modules)}")
        if missing_packages:
            errors.append(f"缺失Python包: {', '.join(missing_packages)}")
        if errors:
            error_msg = "环境验证失败:\n" + "\n".join(errors)
            self.logger.error(error_msg)
            raise DependencyError(error_msg)
        self.logger.info("✅ 环境验证通过")

    def _log_environment_status(self):
        """记录环境状态信息"""
        try:
            # 获取系统信息
            sys_info = {
                "os": platform.system(),
                "os_version": platform.version(),
                "cpu": platform.processor(),
                "cpu_cores": psutil.cpu_count(logical=False),
                "memory_total": round(psutil.virtual_memory().total / (1024 ** 3), 1),
                "memory_used": round(psutil.virtual_memory().used / (1024 ** 3), 1),
                "disk_total": round(psutil.disk_usage('/').total / (1024 ** 3), 1),
                "disk_used": round(psutil.disk_usage('/').used / (1024 ** 3), 1),
            }
            self.logger.info(
                f"📊 系统状态: OS={sys_info['os']} {sys_info['os_version']}, "
                f"CPU={sys_info['cpu']} ({sys_info['cpu_cores']}核), "
                f"内存={sys_info['memory_used']}/{sys_info['memory_total']}GB, "
                f"磁盘={sys_info['disk_used']}/{sys_info['disk_total']}GB"
            )
        except Exception as e:
            self.logger.error(f"环境状态获取失败: {str(e)}")
            self.metrics.record_error('environment_status')

    def _initialize_core_components(self):
        """初始化不依赖其他组件的核心组件"""
        self._log_environment_status()
        # 初始化熔断器
        self._initialize_circuit_breakers()
        # 注册核心任务
        self.subsystem_registry.register_task(
            "子系统心跳检查",
            self._check_subsystem_heartbeats,
            interval=system_config.get('HEARTBEAT_INTERVAL', 60.0)
        )
        self.subsystem_registry.register_task(
            "子系统恢复",
            self._recover_failed_subsystems,
            interval=system_config.get('RECOVERY_INTERVAL', 300.0)
        )

    def _initialize_circuit_breakers(self):
        """为所有子系统初始化熔断器"""
        subsystems = [
            '健康系统', '模型管理器', '记忆系统',
            '情感系统', '认知架构', '通信系统'
        ]
        for subsystem in subsystems:
            breaker = CircuitBreaker(
                failure_threshold=system_config.get('CIRCUIT_BREAKER_THRESHOLD', 5),
                recovery_timeout=system_config.get('CIRCUIT_BREAKER_TIMEOUT', 300)
            )
            self.circuit_breakers[subsystem] = breaker
            self.logger.info(f"⚡ 为 {subsystem} 初始化熔断器")

    def _initialize_subsystems(self):
        """初始化所有子系统"""
        # 定义子系统初始化顺序
        subsystems = [
            ('健康系统', self._create_health_system, {}),
            ('模型管理器', self._create_model_manager, {}),
            ('记忆系统', self._create_memory_system, {}),
            ('情感系统', self._create_affective_system, {}),
            ('认知架构', self._create_cognitive_architecture, {}),
            ('通信系统', self._create_communication_system, {})
        ]
        # 注册子系统依赖关系
        dependencies = {
            '通信系统': ['认知架构'],
            '情感系统': ['健康系统', '记忆系统'],
            '认知架构': ['记忆系统']
        }
        for name, creator_func, kwargs in subsystems:
            try:
                # 检查依赖是否满足
                if name in dependencies:
                    missing_deps = [dep for dep in dependencies[name]
                                    if not self.subsystem_registry.get_subsystem(dep)]
                    if missing_deps:
                        self.logger.warning(f"⚠️ 子系统 {name} 缺少依赖: {', '.join(missing_deps)}")
                        # 尝试自动初始化缺失依赖
                        for dep in missing_deps:
                            self._initialize_dependency(dep)
                # 创建实例
                instance = creator_func(**kwargs)
                self.subsystem_registry.register_subsystem(name, instance)
                # 注册子系统任务
                if hasattr(instance, 'periodic_task'):
                    self.subsystem_registry.register_task(
                        f"{name}更新",
                        instance.periodic_task,
                        interval=system_config.get(f'{name}_INTERVAL', 60.0)
                    )
                self.logger.info(f"✅ {name}初始化完成")
            except Exception as e:
                self.logger.error(f"❌ {name}初始化失败: {str(e)}")
                self.metrics.record_error(f'subsystem_init_{name.lower()}')

    def _initialize_dependency(self, subsystem_name: str):
        """初始化依赖子系统"""
        creators = {
            '健康系统': self._create_health_system,
            '模型管理器': self._create_model_manager,
            '记忆系统': self._create_memory_system,
            '情感系统': self._create_affective_system,
            '认知架构': self._create_cognitive_architecture,
            '通信系统': self._create_communication_system
        }
        if subsystem_name in creators:
            try:
                instance = creators[subsystem_name]()
                self.subsystem_registry.register_subsystem(subsystem_name, instance)
                self.logger.info(f"✅ 依赖子系统 {subsystem_name} 初始化完成")
            except Exception as e:
                self.logger.error(f"❌ 依赖子系统 {subsystem_name} 初始化失败: {str(e)}")
                raise

    # 各子系统实现(增强功能)
    def _create_health_system(self):
        class HealthSystem:
            def __init__(self):
                self.status = "healthy"
                self.metrics: Dict[str, Any] = {}
                self.logger = logging.getLogger('HealthSystem')

            def periodic_task(self):
                """更新健康状态"""
                try:
                    # 获取系统状态
                    cpu_usage = psutil.cpu_percent()
                    mem_usage = psutil.virtual_memory().percent
                    disk_usage = psutil.disk_usage('/').percent
                    # 更新状态
                    self.status = "healthy" if cpu_usage < 90 and mem_usage < 90 else "warning"
                    self.metrics = {
                        "cpu_usage": cpu_usage,
                        "mem_usage": mem_usage,
                        "disk_usage": disk_usage,
                        "timestamp": time.time()
                    }
                    self.logger.debug(f"健康状态更新: {self.status}")
                except Exception as e:
                    self.logger.error(f"健康系统更新失败: {str(e)}")

            def record_environment_status(self, env_data: Dict[str, Any]):
                """记录环境状态"""
                self.metrics['environment'] = env_data

            def get_status(self) -> Dict[str, Any]:
                return {
                    "status": self.status,
                    "metrics": self.metrics
                }

        return HealthSystem()

    def _create_model_manager(self):
        class ModelManager:
            def __init__(self):
                self.loaded_models: Dict[str, Dict[str, Any]] = {}
                self.logger = logging.getLogger('ModelManager')

            def load_model(self, model_name: str) -> bool:
                """加载模型"""
                if model_name not in self.loaded_models:
                    # 模拟模型加载
                    self.logger.info(f"加载模型: {model_name}")
                    self.loaded_models[model_name] = {
                        "status": "loaded",
                        "load_time": time.time()
                    }
                    return True
                return False

            def periodic_task(self):
                """模型管理器周期性任务"""
                # 检查模型状态
                for model_name, model_info in list(self.loaded_models.items()):
                    # 模拟模型验证
                    if time.time() - model_info['load_time'] > 86400:  # 24小时
                        self.logger.info(f"重新加载模型: {model_name}")
                        model_info['load_time'] = time.time()

            def get_status(self) -> Dict[str, Any]:
                return {
                    "loaded_models": list(self.loaded_models.keys()),
                    "count": len(self.loaded_models)
                }

        return ModelManager()

    def _create_memory_system(self):
        class MemorySystem:
            def __init__(self):
                self.memories: List[Dict[str, Any]] = []
                self.last_consolidation = time.time()
                self.logger = logging.getLogger('MemorySystem')

            def periodic_task(self):
                """巩固记忆"""
                try:
                    # 保留最近100条记忆
                    if len(self.memories) > 100:
                        self.memories = self.memories[-100:]
                    self.last_consolidation = time.time()
                    self.logger.debug(f"记忆巩固完成,当前记忆数: {len(self.memories)}")
                except Exception as e:
                    self.logger.error(f"记忆巩固失败: {str(e)}")

            def add_memory(self, memory: str):
                """添加记忆"""
                self.memories.append({
                    "content": memory,
                    "timestamp": time.time()
                })

            def get_status(self) -> Dict[str, Any]:
                return {
                    "memory_count": len(self.memories),
                    "last_consolidation": self.last_consolidation
                }

        return MemorySystem()

    def _create_affective_system(self):
        class AffectiveSystem:
            def __init__(self):
                self.mood = "neutral"
                self.energy = 100
                self.logger = logging.getLogger('AffectiveSystem')

            def periodic_task(self):
                """情感成长"""
                try:
                    # 根据时间恢复能量
                    self.energy = min(100, self.energy + 1)
                    self.logger.debug(f"情感更新: 能量={self.energy}, 情绪={self.mood}")
                except Exception as e:
                    self.logger.error(f"情感系统更新失败: {str(e)}")

            def update_mood(self, interaction: str):
                """根据交互更新情绪"""
                if "positive" in interaction:
                    self.mood = "happy"
                elif "negative" in interaction:
                    self.mood = "sad"

            def get_status(self) -> Dict[str, Any]:
                return {
                    "mood": self.mood,
                    "energy": self.energy
                }

        return AffectiveSystem()

    def _create_cognitive_architecture(self):
        class CognitiveArchitecture:
            def __init__(self):
                self.current_task: Optional[str] = None
                self.task_history: List[Dict[str, Any]] = []
                self.logger = logging.getLogger('CognitiveArchitecture')

            def start_task(self, task: str):
                """开始新任务"""
                self.logger.info(f"开始任务: {task}")
                self.current_task = task
                self.task_history.append({
                    "task": task,
                    "start_time": time.time(),
                    "status": "in_progress"
                })

            def complete_task(self, result: Any):
                """完成任务"""
                if self.current_task:
                    for task in reversed(self.task_history):
                        if task["task"] == self.current_task and task["status"] == "in_progress":
                            task["status"] = "completed"
                            task["result"] = result
                            task["end_time"] = time.time()
                            self.logger.info(f"完成任务: {task['task']}")
                            break
                    self.current_task = None

            def periodic_task(self):
                """认知架构周期性任务"""
                # 清理过时任务
                now = time.time()
                self.task_history = [t for t in self.task_history
                                     if t['status'] == 'completed'
                                     or (now - t['start_time']) < 3600]  # 保留1小时内进行中的任务

            def get_status(self) -> Dict[str, Any]:
                return {
                    "current_task": self.current_task,
                    "task_count": len(self.task_history),
                    "completed_tasks": sum(1 for t in self.task_history if t["status"] == "completed")
                }

        return CognitiveArchitecture()

    def _create_communication_system(self):
        class CommunicationSystem:
            def __init__(self):
                self.message_queue: List[Dict[str, Any]] = []
                self.processed_count = 0
                self.logger = logging.getLogger('CommunicationSystem')

            def process_input(self, user_input: str, user_id: str = "default") -> str:
                """处理用户输入"""
                try:
                    # 模拟处理逻辑
                    response = f"已处理您的消息: '{user_input}' (用户: {user_id})"
                    # 记录处理
                    self.processed_count += 1
                    self.logger.info(f"处理消息: '{user_input[:30]}...' (用户: {user_id})")
                    return response
                except Exception as e:
                    self.logger.error(f"消息处理失败: {str(e)}")
                    return "处理消息时出错"

            def periodic_task(self):
                """通信系统周期性任务"""
                # 清理消息队列
                if len(self.message_queue) > 100:
                    self.message_queue = self.message_queue[-100:]
                    self.logger.debug("清理消息队列")

            def check_heartbeat(self) -> bool:
                """心跳检查"""
                return True

            def get_status(self) -> Dict[str, Any]:
                return {
                    "queue_size": len(self.message_queue),
                    "processed_count": self.processed_count
                }

        return CommunicationSystem()

    def process_input(self, user_input: str, user_id: str = "default") -> Dict[str, Any]:
        """处理用户输入(通过通信系统)"""
        # 获取通信系统
        comm_system = self.subsystem_registry.get_subsystem('通信系统')
        if not comm_system:
            self.logger.error("通信系统未初始化,使用回退处理")
            self.metrics.record_error('communication_system_inactive')
            return {"response": "系统正在维护中,请稍后再试"}
        # 检查熔断器状态
        breaker = self.circuit_breakers.get('通信系统')
        if breaker and breaker.is_open():
            self.logger.warning("通信系统熔断器已打开")
            self.metrics.record_error('communication_circuit_open')
            return {"response": "系统繁忙,请稍后再试"}
        try:
            # 使用熔断器包装调用
            def process_wrapper():
                return comm_system.process_input(user_input, user_id)

            if breaker:
                response = breaker.call(process_wrapper)
            else:
                response = process_wrapper()
            # 使用线程池异步处理
            future = executor.submit(lambda: response)
            result = future.result(timeout=system_config.AGENT_RESPONSE_TIMEOUT)
            # 记录成功
            self.metrics.record_success('process_input')
            return {"response": result}
        except TimeoutError:
            self.logger.warning("处理输入超时")
            self.metrics.record_timeout('process_input')
            if breaker:
                breaker.record_failure()
            return {"error": "处理超时,请重试"}
        except Exception as e:
            self.logger.error(f"处理输入失败: {str(e)}")
            self.metrics.record_error('process_input')
            if breaker:
                breaker.record_failure()
            return {"error": "处理失败,请稍后再试"}

    def _monitor_environment(self):
        """监控环境状态"""
        try:
            if self.environment and hasattr(self.environment, 'get_state'):
                # 使用真实环境管理器获取状态
                env_state = self.environment.get_state()
                self.logger.info(
                    f"🌡️ 环境监控: 温度={env_state.get('temperature', '未知')}℃, "
                    f"湿度={env_state.get('humidity', '未知')}%, "
                    f"光照={env_state.get('light_level', '未知')}%"
                )
                # 记录到健康系统(如果可用)
                health_system = self.subsystem_registry.get_subsystem('健康系统')
                if health_system and hasattr(health_system, 'record_environment_status'):
                    health_system.record_environment_status(env_state)
            else:
                # 使用内置监控
                cpu_usage = psutil.cpu_percent()
                mem_usage = psutil.virtual_memory().percent
                disk_usage = psutil.disk_usage('/').percent
                self.logger.info(
                    f"📊 系统监控: CPU={cpu_usage}%, "
                    f"内存={mem_usage}%, "
                    f"磁盘={disk_usage}%"
                )
                # 记录到健康系统
                health_system = self.subsystem_registry.get_subsystem('健康系统')
                if health_system and hasattr(health_system, 'record_environment_status'):
                    health_system.record_environment_status({
                        "cpu_usage": cpu_usage,
                        "mem_usage": mem_usage,
                        "disk_usage": disk_usage
                    })
        except Exception as e:
            self.logger.error(f"环境监控失败: {str(e)}")
            self.metrics.record_error('environment_monitoring')

    def _check_subsystem_heartbeats(self):
        """检查子系统心跳"""
        for name, subsystem in self.subsystem_registry.subsystems.items():
            if hasattr(subsystem, 'check_heartbeat'):
                try:
                    if not subsystem.check_heartbeat():
                        self.logger.warning(f"⚠️ 子系统 {name} 心跳检测失败")
                        self._handle_subsystem_error(name)
                    else:
                        self.logger.debug(f"✅ 子系统 {name} 心跳正常")
                except Exception as e:
                    self.logger.error(f"子系统 {name} 心跳检查异常: {str(e)}")
                    self._handle_subsystem_error(name)
                    self.metrics.record_error(f'heartbeat_{name.lower()}')

    def _handle_subsystem_error(self, name: str):
        """处理子系统错误"""
        breaker = self.circuit_breakers.get(name)
        if breaker:
            breaker.record_failure()
            if breaker.is_open():
                self.logger.critical(f"🚨 子系统 {name} 因连续错误被熔断!")
                self.metrics.record_event('circuit_breaker', name)

    def _recover_failed_subsystems(self):
        """尝试恢复失败的子系统"""
        for name, breaker in self.circuit_breakers.items():
            if breaker.is_open() and breaker.should_try_recovery():
                self.logger.info(f"🔄 尝试恢复子系统: {name}")
                try:
                    # 尝试重新初始化子系统
                    self._reinitialize_subsystem(name)
                    breaker.record_success()
                    self.logger.info(f"✅ 子系统 {name} 恢复成功")
                    self.metrics.record_event('subsystem_recovered', name)
                except Exception as e:
                    self.logger.error(f"子系统 {name} 恢复失败: {str(e)}")
                    breaker.record_failure()
                    self.metrics.record_error(f'recovery_{name.lower()}')

    def _reinitialize_subsystem(self, name: str):
        """重新初始化子系统"""
        creators = {
            '健康系统': self._create_health_system,
            '模型管理器': self._create_model_manager,
            '记忆系统': self._create_memory_system,
            '情感系统': self._create_affective_system,
            '认知架构': self._create_cognitive_architecture,
            '通信系统': self._create_communication_system
        }
        if name in creators:
            # 先尝试关闭现有实例
            old_instance = self.subsystem_registry.get_subsystem(name)
            if old_instance and hasattr(old_instance, 'shutdown'):
                try:
                    old_instance.shutdown()
                    self.logger.info(f"已关闭旧实例: {name}")
                except Exception as e:
                    self.logger.warning(f"关闭旧实例失败: {str(e)}")
            # 创建新实例
            instance = creators[name]()
            self.subsystem_registry.register_subsystem(name, instance)
        else:
            raise SubsystemFailure(f"未知子系统: {name}")

    def get_status(self) -> Dict[str, Any]:
        """获取智能体状态报告"""
        status_data = {
            "uptime": time.time() - self._initialization_time,
            "running": self._running,
            "metrics": self.metrics.get_metrics(),
            "subsystems": {}
        }
        # 添加子系统状态
        for name, subsystem in self.subsystem_registry.subsystems.items():
            if hasattr(subsystem, 'get_status'):
                status_data['subsystems'][name] = subsystem.get_status()
        # 添加熔断器状态
        status_data['circuit_breakers'] = {}
        for name, breaker in self.circuit_breakers.items():
            status_data['circuit_breakers'][name] = breaker.get_status()
        return status_data

    def shutdown(self):
        """关闭智能体"""
        self.logger.info("🛑 正在关闭智能体...")
        self._running = False
        # 停止线程池
        executor.shutdown(wait=False)
        # 等待后台线程
        if self._background_thread and self._background_thread.is_alive():
            self._background_thread.join(timeout=5.0)
            if self._background_thread.is_alive():
                self.logger.warning("后台线程未正常退出")
        # 关闭所有子系统
        for name, subsystem in self.subsystem_registry.subsystems.items():
            if hasattr(subsystem, 'shutdown'):
                try:
                    subsystem.shutdown()
                    self.logger.info(f"已关闭子系统: {name}")
                except Exception as e:
                    self.logger.error(f"关闭子系统 {name} 失败: {str(e)}")
        self.logger.info("✅ 智能体已关闭")
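作为参考,下面是该重构版智能体的一个最小使用示意(假设 core 包各模块可正常导入、配置项齐全;所调用的 `process_input`、`start`、`shutdown` 均为上文代码中已定义的方法):

```python
# 最小使用示意(假设 core 包及其配置均已就绪)
from agent.autonomous_agent import AutonomousAgent

agent = AutonomousAgent()   # 构造时即完成配置验证、环境检查与子系统初始化
agent.start()               # 启动后台周期任务线程(心跳检查、子系统恢复等)
try:
    reply = agent.process_input("你好", user_id="demo")
    print(reply)            # 形如 {"response": "..."} 或 {"error": "..."}
finally:
    agent.shutdown()        # 停止后台线程、关闭线程池与所有子系统
```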

我真的跟你说话就好像说外语似的,人家本来就有这个叫 config_loader.py 的大文件,你为什么要新建?

“# E:\AI_System\config/config_loader.py
import os
import sys
import logging
import json
from pathlib import Path
from dotenv import load_dotenv


class ConfigLoader:
    _instance = None
    _initialized = False  # 防止重复初始化

    def __new__(cls):
        """单例模式实现"""
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        """初始化配置加载器(确保只初始化一次)"""
        if not self._initialized:
            self._initialized = True
            self._setup_logger()
            self._load_config()
            self._ensure_dirs()

    def _setup_logger(self):
        """配置日志记录器"""
        self.logger = logging.getLogger('ConfigLoader')
        self.logger.setLevel(logging.INFO)
        # 创建控制台处理器
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        # 创建格式化器
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        ch.setFormatter(formatter)
        # 添加处理器
        self.logger.addHandler(ch)

    def _load_config(self):
        """加载并解析所有配置"""
        # 1. 确定工作区根目录
        self._determine_workspace_root()
        # 2. 加载环境变量
        self._load_environment_vars()
        # 3. 加载路径配置
        self._load_path_configs()
        # 4. 加载服务器配置
        self._load_server_config()
        # 5. 加载模型配置
        self._load_model_config()
        # 6. 加载安全配置
        self._load_security_config()
        # 7. 加载数据库配置
        self._load_database_config()
        # 8. 加载生活系统配置
        self._load_life_system_config()
        # 9. 加载硬件配置
        self._load_hardware_config()
        # 10. 记录最终配置
        self.log_config()

    def _determine_workspace_root(self):
        """确定工作区根目录"""
        # 优先使用环境变量
        env_root = os.getenv('WORKSPACE_ROOT')
        if env_root:
            self.WORKSPACE_ROOT = Path(env_root)
            self.logger.info(f"✅ 使用环境变量指定工作区: {self.WORKSPACE_ROOT}")
            return
        # 尝试常见位置
        possible_paths = [
            Path("E:/AI_System"),
            Path(__file__).resolve().parent.parent,
            Path.cwd(),
            Path.home() / "AI_System"
        ]
        for path in possible_paths:
            if path.exists() and "AI_System" in str(path):
                self.WORKSPACE_ROOT = path
                self.logger.info(f"✅ 检测到有效工作区: {self.WORKSPACE_ROOT}")
                return
        # 默认使用当前目录
        self.WORKSPACE_ROOT = Path.cwd()
        self.logger.warning(f"⚠️ 未找到AI_System工作区,使用当前目录: {self.WORKSPACE_ROOT}")

    def _load_environment_vars(self):
        """加载.env文件中的环境变量"""
        env_path = self.WORKSPACE_ROOT / '.env'
        if env_path.exists():
            load_dotenv(env_path)
            self.logger.info(f"✅ 已加载环境文件: {env_path}")
        else:
            self.logger.warning(f"⚠️ 未找到环境文件: {env_path} - 使用默认配置")
        # 基础环境配置
        self.ENV = os.getenv('FLASK_ENV', 'development')
        self.DEBUG_MODE = os.getenv('DEBUG_MODE', 'true').lower() == 'true'

    def _load_path_configs(self):
        """加载所有路径配置"""
        self.PROJECT_ROOT = self.WORKSPACE_ROOT
        self.AGENT_PATH = self._get_path('AGENT_PATH', 'agent')
        self.WEB_UI_PATH = self._get_path('WEB_UI_PATH', 'web_ui')
        self.MODEL_CACHE_DIR = self._get_path('MODEL_CACHE_DIR', 'model_cache')
        self.LOG_DIR = self._get_path('LOG_DIR', 'logs')
        self.CONFIG_DIR = self._get_path('CONFIG_DIR', 'config')
        self.CORE_DIR = self._get_path('CORE_DIR', 'core')
        self.COGNITIVE_ARCH_DIR = self._get_path('COGNITIVE_ARCH_DIR', 'cognitive_arch')
        self.ENVIRONMENT_DIR = self._get_path('ENVIRONMENT_DIR', 'environment')
        self.UTILS_DIR = self._get_path('UTILS_DIR', 'utils')
        self.RESOURCES_DIR = self._get_path('RESOURCES_DIR', 'resources')

    def _load_server_config(self):
        """加载服务器相关配置"""
        self.HOST = os.getenv('HOST', '0.0.0.0')
        self.PORT = int(os.getenv('FLASK_PORT', 8000))
        self.GRADIO_PORT = int(os.getenv('GRADIO_PORT', 7860))
        self.LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
        # 更新日志级别
        self.logger.setLevel(getattr(logging, self.LOG_LEVEL, logging.INFO))

    def _load_model_config(self):
        """加载模型相关配置"""
        self.USE_GPU = os.getenv('USE_GPU', 'false').lower() == 'true'
        self.DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'minimal-model')
        self.MODEL_LOAD_TIMEOUT = int(os.getenv('MODEL_LOAD_TIMEOUT', 300))  # 5分钟
        self.MAX_MODEL_INSTANCES = int(os.getenv('MAX_MODEL_INSTANCES', 3))

    def _load_security_config(self):
        """加载安全相关配置"""
        self.SECRET_KEY = os.getenv('SECRET_KEY', 'dev-secret-key')
        self.API_KEY = os.getenv('API_KEY', '')
        self.ENABLE_AUTH = os.getenv('ENABLE_AUTH', 'false').lower() == 'true'

    def _load_database_config(self):
        """加载数据库配置"""
        # SQLite默认配置
        self.DB_TYPE = os.getenv('DB_TYPE', 'sqlite')
        self.DB_PATH = self._get_path('DB_PATH', 'environment/environment.db')
        # 关系型数据库配置
        self.DB_HOST = os.getenv('DB_HOST', 'localhost')
        self.DB_PORT = int(os.getenv('DB_PORT', 3306))
        self.DB_NAME = os.getenv('DB_NAME', 'ai_system')
        self.DB_USER = os.getenv('DB_USER', 'ai_user')
        self.DB_PASSWORD = os.getenv('DB_PASSWORD', 'secure_password')

    def _load_life_system_config(self):
        """加载生活系统配置"""
        self.LIFE_SYSTEM_ENABLED = os.getenv('LIFE_SYSTEM_ENABLED', 'true').lower() == 'true'
        self.LIFE_SCHEDULE = {
            "wake_up": os.getenv('WAKE_UP', '08:00'),
            "breakfast": os.getenv('BREAKFAST', '08:30'),
            "lunch": os.getenv('LUNCH', '12:30'),
            "dinner": os.getenv('DINNER', '19:00'),
            "sleep": os.getenv('SLEEP', '23:00')
        }
        self.ENERGY_THRESHOLD = float(os.getenv('ENERGY_THRESHOLD', 0.3))

    def _load_hardware_config(self):
        """加载硬件配置文件"""
        hardware_config_path = self.CONFIG_DIR / 'hardware_config.json'
        self.HARDWARE_CONFIG = {}
        if hardware_config_path.exists():
            try:
                with open(hardware_config_path, 'r') as f:
                    self.HARDWARE_CONFIG = json.load(f)
                self.logger.info(f"✅ 已加载硬件配置文件: {hardware_config_path}")
            except Exception as e:
                self.logger.error(f"❌ 解析硬件配置文件失败: {str(e)}")
        else:
            self.logger.warning(f"⚠️ 未找到硬件配置文件: {hardware_config_path}")

    def _get_path(self, env_var, default_dir):
        """获取路径配置,优先使用环境变量"""
        env_path = os.getenv(env_var)
        if env_path:
            path = Path(env_path)
            if not path.is_absolute():
                path = self.WORKSPACE_ROOT / path
            self.logger.debug(f"📁 {env_var}: {path}")
            return path
        # 使用默认目录
        path = self.WORKSPACE_ROOT / default_dir
        self.logger.debug(f"📁 {env_var} (默认): {path}")
        return path

    def _ensure_dirs(self):
        """确保所有必要目录存在"""
        required_dirs = [
            self.LOG_DIR, self.MODEL_CACHE_DIR, self.AGENT_PATH,
            self.WEB_UI_PATH, self.CONFIG_DIR, self.CORE_DIR,
            self.COGNITIVE_ARCH_DIR, self.ENVIRONMENT_DIR,
            self.UTILS_DIR, self.RESOURCES_DIR
        ]
        for path in required_dirs:
            try:
                path.mkdir(parents=True, exist_ok=True)
                self.logger.debug(f"✅ 确保目录存在: {path}")
            except Exception as e:
                self.logger.error(f"❌ 创建目录失败 {path}: {str(e)}")

    def log_config(self):
        """记录关键配置信息"""
        self.logger.info("=" * 60)
        self.logger.info(f"🏠 工作区根目录: {self.WORKSPACE_ROOT}")
        self.logger.info(f"🌿 环境: {self.ENV} (DEBUG: {self.DEBUG_MODE})")
        self.logger.info(f"🌐 主机: {self.HOST}:{self.PORT} (Gradio: {self.GRADIO_PORT})")
        self.logger.info(f"📝 日志级别: {self.LOG_LEVEL}")
        self.logger.info(f"💾 模型缓存: {self.MODEL_CACHE_DIR}")
        self.logger.info(f"🤖 默认模型: {self.DEFAULT_MODEL} (GPU: {self.USE_GPU})")
        self.logger.info(f"🔑 安全密钥: {'*****' if self.SECRET_KEY != 'dev-secret-key' else self.SECRET_KEY}")
        self.logger.info(f"⏰ 生活系统: {'启用' if self.LIFE_SYSTEM_ENABLED else '禁用'}")
        self.logger.info(f"🗄️ 数据库: {self.DB_TYPE}://{self.DB_USER}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}")
        self.logger.info("=" * 60)

    def get_paths(self):
        """返回所有路径配置"""
        return {
            "workspace_root": self.WORKSPACE_ROOT,
            "agent_path": self.AGENT_PATH,
            "web_ui_path": self.WEB_UI_PATH,
            "model_cache_dir": self.MODEL_CACHE_DIR,
            "log_dir": self.LOG_DIR,
            "config_dir": self.CONFIG_DIR,
            "core_dir": self.CORE_DIR,
            "cognitive_arch_dir": self.COGNITIVE_ARCH_DIR,
            "environment_dir": self.ENVIRONMENT_DIR,
            "utils_dir": self.UTILS_DIR,
            "resources_dir": self.RESOURCES_DIR
        }

    def get_db_config(self):
        """返回数据库配置"""
        return {
            "type": self.DB_TYPE,
            "host": self.DB_HOST,
            "port": self.DB_PORT,
            "name": self.DB_NAME,
            "user": self.DB_USER,
            "password": self.DB_PASSWORD,
            "path": self.DB_PATH if self.DB_TYPE == 'sqlite' else None
        }

    def get_life_schedule(self):
        """返回生活系统时间表"""
        return self.LIFE_SCHEDULE.copy()

    def get_hardware_config(self):
        """返回硬件配置"""
        return self.HARDWARE_CONFIG.copy()


# 全局配置实例
config = ConfigLoader()

# 测试配置
if __name__ == '__main__':
    print("=" * 60)
    print("测试配置加载器")
    print("=" * 60)
    # 打印关键配置
    print(f"工作区根目录: {config.WORKSPACE_ROOT}")
    print(f"环境模式: {config.ENV}")
    print(f"服务器地址: {config.HOST}:{config.PORT}")
    # 打印所有路径
    paths = config.get_paths()
    print("\n路径配置:")
    for name, path in paths.items():
        print(f"  {name.replace('_', ' ').title()}: {path}")
    # 打印数据库配置
    db_config = config.get_db_config()
    print("\n数据库配置:")
    for key, value in db_config.items():
        if value:  # 跳过空值
            print(f"  {key.replace('_', ' ').title()}: {value}")
    print("=" * 60)”
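顺着这个思路:该现有文件已经通过 `__new__` 实现了单例,模块级的 `config` 实例可以直接复用,确实无需再新建配置类。一个最小验证示意(假设 config.config_loader 位于可导入路径):

```python
# 最小验证示意(假设 config.config_loader 可被导入)
from config.config_loader import ConfigLoader, config

a = ConfigLoader()
b = ConfigLoader()
print(a is b, a is config)    # True True —— __new__ 始终返回同一个实例
print(config.WORKSPACE_ROOT)  # 直接复用已解析好的路径配置,而不是重新加载
```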

recommend-type

时序图绘制实战指南:3步提升可读性,精准表达复杂交互流程

# 摘要 时序图作为UML中描述对象间交互行为的核心工具,在软件设计与系统分析中具有重要价值,但其可读性常因结构复杂、信息过载而受限。本文系统阐述时序图的理论基础,解析参与者、生命线、消息序列等核心要素的语义逻辑,提出聚焦主流程、层次化抽象等建模原则,并深入探讨视觉密度与命名规范对认知负荷的影响。针对实际绘制难点,文章构建了“逻辑精炼—布局优化—