@@ -250,19 +250,25 @@ def jsonify(self):
250
250
251
251
class GPUStatCollection (object ):
252
252
253
- def __init__ (self , gpu_list ):
253
+ def __init__ (self , gpu_list , driver_version = None ):
254
254
self .gpus = gpu_list
255
255
256
256
# attach additional system information
257
257
self .hostname = platform .node ()
258
258
self .query_time = datetime .now ()
259
+ self .driver_version = driver_version
259
260
260
261
@staticmethod
261
262
def new_query ():
262
263
"""Query the information of all the GPUs on local machine"""
263
264
264
265
N .nvmlInit ()
265
266
267
+ def _decode (b ):
268
+ if isinstance (b , bytes ):
269
+ return b .decode () # for python3, to unicode
270
+ return b
271
+
266
272
def get_gpu_info (handle ):
267
273
"""Get one GPU information specified by nvml handle"""
268
274
@@ -284,11 +290,6 @@ def get_process_info(nv_process):
284
290
process ['pid' ] = nv_process .pid
285
291
return process
286
292
287
- def _decode (b ):
288
- if isinstance (b , bytes ):
289
- return b .decode () # for python3, to unicode
290
- return b
291
-
292
293
name = _decode (N .nvmlDeviceGetName (handle ))
293
294
uuid = _decode (N .nvmlDeviceGetUUID (handle ))
294
295
@@ -374,8 +375,14 @@ def _decode(b):
374
375
gpu_stat = GPUStat (gpu_info )
375
376
gpu_list .append (gpu_stat )
376
377
378
+ # 2. additional info (driver version, etc).
379
+ try :
380
+ driver_version = _decode (N .nvmlSystemGetDriverVersion ())
381
+ except N .NVMLError :
382
+ driver_version = None # N/A
383
+
377
384
N .nvmlShutdown ()
378
- return GPUStatCollection (gpu_list )
385
+ return GPUStatCollection (gpu_list , driver_version = driver_version )
379
386
380
387
def __len__ (self ):
381
388
return len (self .gpus )
@@ -424,15 +431,19 @@ def print_formatted(self, fp=sys.stdout, force_color=False, no_color=False,
424
431
if show_header :
425
432
time_format = locale .nl_langinfo (locale .D_T_FMT )
426
433
427
- header_template = '{t.bold_white}{hostname:{width}}{t.normal} {timestr}' # noqa: E501
434
+ header_template = '{t.bold_white}{hostname:{width}}{t.normal} '
435
+ header_template += '{timestr} '
436
+ header_template += '{t.bold_black}{driver_version}{t.normal}'
437
+
428
438
header_msg = header_template .format (
429
439
hostname = self .hostname ,
430
440
width = gpuname_width + 3 , # len("[?]")
431
441
timestr = self .query_time .strftime (time_format ),
442
+ driver_version = self .driver_version ,
432
443
t = t_color ,
433
444
)
434
445
435
- fp .write (header_msg )
446
+ fp .write (header_msg . strip () )
436
447
fp .write (eol_char )
437
448
438
449
# body
0 commit comments