@@ -1428,24 +1428,12 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
14281428
14291429 description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
14301430 print (secondary_training_status_message (description , None ), end = "" )
1431- instance_count = description ["ResourceConfig" ]["InstanceCount" ]
1432- status = description ["TrainingJobStatus" ]
1433-
1434- stream_names = [] # The list of log streams
1435- positions = {} # The current position in each stream, map of stream name -> position
1436-
1437- # Increase retries allowed (from default of 4), as we don't want waiting for a training job
1438- # to be interrupted by a transient exception.
1439- config = botocore .config .Config (retries = {"max_attempts" : 15 })
1440- client = self .boto_session .client ("logs" , config = config )
1441- log_group = "/aws/sagemaker/TrainingJobs"
14421431
1443- job_already_completed = status in ("Completed" , "Failed" , "Stopped" )
1444-
1445- state = LogState .TAILING if wait and not job_already_completed else LogState .COMPLETE
1446- dot = False
1432+ instance_count , stream_names , positions , client , log_group , dot , color_wrap = _logs_init (
1433+ self , description , job = "Training"
1434+ )
14471435
1448- color_wrap = sagemaker . logs . ColorWrap ( )
1436+ state = _get_initial_job_state ( description , "TrainingJobStatus" , wait )
14491437
14501438 # The loop below implements a state machine that alternates between checking the job status
14511439 # and reading whatever is available in the logs at this point. Note, that if we were
@@ -1470,52 +1458,16 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
14701458 last_describe_job_call = time .time ()
14711459 last_description = description
14721460 while True :
1473- if len (stream_names ) < instance_count :
1474- # Log streams are created whenever a container starts writing to stdout/err, so
1475- # this list # may be dynamic until we have a stream for every instance.
1476- try :
1477- streams = client .describe_log_streams (
1478- logGroupName = log_group ,
1479- logStreamNamePrefix = job_name + "/" ,
1480- orderBy = "LogStreamName" ,
1481- limit = instance_count ,
1482- )
1483- stream_names = [s ["logStreamName" ] for s in streams ["logStreams" ]]
1484- positions .update (
1485- [
1486- (s , sagemaker .logs .Position (timestamp = 0 , skip = 0 ))
1487- for s in stream_names
1488- if s not in positions
1489- ]
1490- )
1491- except ClientError as e :
1492- # On the very first training job run on an account, there's no log group until
1493- # the container starts logging, so ignore any errors thrown about that
1494- err = e .response .get ("Error" , {})
1495- if err .get ("Code" , None ) != "ResourceNotFoundException" :
1496- raise
1497-
1498- if len (stream_names ) > 0 :
1499- if dot :
1500- print ("" )
1501- dot = False
1502- for idx , event in sagemaker .logs .multi_stream_iter (
1503- client , log_group , stream_names , positions
1504- ):
1505- color_wrap (idx , event ["message" ])
1506- ts , count = positions [stream_names [idx ]]
1507- if event ["timestamp" ] == ts :
1508- positions [stream_names [idx ]] = sagemaker .logs .Position (
1509- timestamp = ts , skip = count + 1
1510- )
1511- else :
1512- positions [stream_names [idx ]] = sagemaker .logs .Position (
1513- timestamp = event ["timestamp" ], skip = 1
1514- )
1515- else :
1516- dot = True
1517- print ("." , end = "" )
1518- sys .stdout .flush ()
1461+ _flush_log_streams (
1462+ stream_names ,
1463+ instance_count ,
1464+ client ,
1465+ log_group ,
1466+ job_name ,
1467+ positions ,
1468+ dot ,
1469+ color_wrap ,
1470+ )
15191471 if state == LogState .COMPLETE :
15201472 break
15211473
@@ -1554,6 +1506,86 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
15541506 saving = (1 - float (billable_time ) / training_time ) * 100
15551507 print ("Managed Spot Training savings: {:.1f}%" .format (saving ))
15561508
1509+ def logs_for_transform_job (self , job_name , wait = False , poll = 10 ):
1510+ """Display the logs for a given transform job, optionally tailing them until the
1511+ job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
1512+ based on which instance the log entry is from.
1513+
1514+ Args:
1515+ job_name (str): Name of the transform job to display the logs for.
1516+ wait (bool): Whether to keep looking for new log entries until the job completes
1517+ (default: False).
1518+ poll (int): The interval in seconds between polling for new log entries and job
1519+ completion (default: 5).
1520+
1521+ Raises:
1522+ ValueError: If the transform job fails.
1523+ """
1524+
1525+ description = self .sagemaker_client .describe_transform_job (TransformJobName = job_name )
1526+
1527+ instance_count , stream_names , positions , client , log_group , dot , color_wrap = _logs_init (
1528+ self , description , job = "Transform"
1529+ )
1530+
1531+ state = _get_initial_job_state (description , "TransformJobStatus" , wait )
1532+
1533+ # The loop below implements a state machine that alternates between checking the job status
1534+ # and reading whatever is available in the logs at this point. Note, that if we were
1535+ # called with wait == False, we never check the job status.
1536+ #
1537+ # If wait == TRUE and job is not completed, the initial state is TAILING
1538+ # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
1539+ # complete).
1540+ #
1541+ # The state table:
1542+ #
1543+ # STATE ACTIONS CONDITION NEW STATE
1544+ # ---------------- ---------------- ----------------- ----------------
1545+ # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
1546+ # Else TAILING
1547+ # JOB_COMPLETE Read logs, Pause Any COMPLETE
1548+ # COMPLETE Read logs, Exit N/A
1549+ #
1550+ # Notes:
1551+ # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
1552+ # Cloudwatch after the job was marked complete.
1553+ last_describe_job_call = time .time ()
1554+ while True :
1555+ _flush_log_streams (
1556+ stream_names ,
1557+ instance_count ,
1558+ client ,
1559+ log_group ,
1560+ job_name ,
1561+ positions ,
1562+ dot ,
1563+ color_wrap ,
1564+ )
1565+ if state == LogState .COMPLETE :
1566+ break
1567+
1568+ time .sleep (poll )
1569+
1570+ if state == LogState .JOB_COMPLETE :
1571+ state = LogState .COMPLETE
1572+ elif time .time () - last_describe_job_call >= 30 :
1573+ description = self .sagemaker_client .describe_transform_job (
1574+ TransformJobName = job_name
1575+ )
1576+ last_describe_job_call = time .time ()
1577+
1578+ status = description ["TransformJobStatus" ]
1579+
1580+ if status in ("Completed" , "Failed" , "Stopped" ):
1581+ print ()
1582+ state = LogState .JOB_COMPLETE
1583+
1584+ if wait :
1585+ self ._check_job_status (job_name , description , "TransformJobStatus" )
1586+ if dot :
1587+ print ()
1588+
15571589
15581590def container_def (image , model_data_url = None , env = None ):
15591591 """Create a definition for executing a container as part of a SageMaker model.
@@ -1892,3 +1924,83 @@ def _vpc_config_from_training_job(
18921924 if vpc_config_override is vpc_utils .VPC_CONFIG_DEFAULT :
18931925 return training_job_desc .get (vpc_utils .VPC_CONFIG_KEY )
18941926 return vpc_utils .sanitize (vpc_config_override )
1927+
1928+
def _get_initial_job_state(description, status_key, wait):
    """Pick the starting LogState for the log-tailing state machine.

    Tail only when the caller asked to wait and the job has not already
    reached a terminal status; otherwise read existing logs once and stop.
    """
    finished = description[status_key] in ("Completed", "Failed", "Stopped")
    if wait and not finished:
        return LogState.TAILING
    return LogState.COMPLETE
1934+
1935+
1936+ def _logs_init (sagemaker_session , description , job ):
1937+ """Placeholder docstring"""
1938+ if job == "Training" :
1939+ instance_count = description ["ResourceConfig" ]["InstanceCount" ]
1940+ elif job == "Transform" :
1941+ instance_count = description ["TransformResources" ]["InstanceCount" ]
1942+
1943+ stream_names = [] # The list of log streams
1944+ positions = {} # The current position in each stream, map of stream name -> position
1945+
1946+ # Increase retries allowed (from default of 4), as we don't want waiting for a training job
1947+ # to be interrupted by a transient exception.
1948+ config = botocore .config .Config (retries = {"max_attempts" : 15 })
1949+ client = sagemaker_session .boto_session .client ("logs" , config = config )
1950+ log_group = "/aws/sagemaker/" + job + "Jobs"
1951+
1952+ dot = False
1953+
1954+ color_wrap = sagemaker .logs .ColorWrap ()
1955+
1956+ return instance_count , stream_names , positions , client , log_group , dot , color_wrap
1957+
1958+
def _flush_log_streams(
    stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
):
    """Refresh the log stream list and print any new CloudWatch events.

    While fewer streams than instances are known, re-list the streams (they are
    created lazily, once a container starts writing to stdout/stderr). If any
    streams exist, print their new events through ``color_wrap`` and advance
    ``positions``; otherwise print a progress dot.

    Args:
        stream_names (list): Known log stream names. NOTE: reassigned locally;
            callers must use the returned value to see updates.
        instance_count (int): Number of instances (and hence expected streams).
        client: CloudWatch Logs client.
        log_group (str): Log group to read from.
        job_name (str): Job name, used as the stream-name prefix.
        positions (dict): Stream name -> Position; mutated in place.
        dot (bool): Whether the last output was a progress dot. NOTE: immutable,
            so callers must use the returned value to see updates.
        color_wrap: Callable used to print a single log event.

    Returns:
        tuple: The updated ``(stream_names, dot)``. The original inline code
        mutated these in the caller's scope; after the refactor into this
        helper the local reassignments are invisible to the caller, so the
        updated values are returned instead (callers that ignore the return
        keep the old, degraded behavior of re-listing streams every poll).
    """
    if len(stream_names) < instance_count:
        # Log streams are created whenever a container starts writing to stdout/err, so this list
        # may be dynamic until we have a stream for every instance.
        try:
            streams = client.describe_log_streams(
                logGroupName=log_group,
                logStreamNamePrefix=job_name + "/",
                orderBy="LogStreamName",
                limit=instance_count,
            )
            stream_names = [s["logStreamName"] for s in streams["logStreams"]]
            positions.update(
                [
                    (s, sagemaker.logs.Position(timestamp=0, skip=0))
                    for s in stream_names
                    if s not in positions
                ]
            )
        except ClientError as e:
            # On the very first training job run on an account, there's no log group until
            # the container starts logging, so ignore any errors thrown about that
            err = e.response.get("Error", {})
            if err.get("Code", None) != "ResourceNotFoundException":
                raise

    if len(stream_names) > 0:
        if dot:
            # Terminate the run of progress dots before printing real log lines.
            print("")
            dot = False
        for idx, event in sagemaker.logs.multi_stream_iter(
            client, log_group, stream_names, positions
        ):
            color_wrap(idx, event["message"])
            ts, count = positions[stream_names[idx]]
            if event["timestamp"] == ts:
                # Same timestamp as last seen: skip one more event next time.
                positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
            else:
                positions[stream_names[idx]] = sagemaker.logs.Position(
                    timestamp=event["timestamp"], skip=1
                )
    else:
        dot = True
        print(".", end="")
        sys.stdout.flush()

    return stream_names, dot
0 commit comments