composer.loggers.in_memory_logger
Logs metrics to dictionary objects that persist in memory throughout training.
Useful for collecting and plotting data inside notebooks.
Classes
InMemoryLogger | Logs metrics to dictionary objects that persist in memory throughout training.
- class composer.loggers.in_memory_logger.InMemoryLogger(log_level=LogLevel.BATCH)
Bases: composer.loggers.logger_destination.LoggerDestination
Logs metrics to dictionary objects that persist in memory throughout training.
Useful for collecting and plotting data inside notebooks.
- Example usage:
from composer.loggers import InMemoryLogger, LogLevel
from composer.trainer import Trainer

# `model`, `train_dataloader`, `eval_dataloader`, and `optimizer` are assumed
# to be defined already.
logger = InMemoryLogger(
    log_level=LogLevel.BATCH
)
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    optimizers=[optimizer],
    loggers=[logger]
)

# Train, so that the logger has metrics to record.
trainer.fit()

# Get data from logger. If you are using multiple loggers, be sure to confirm
# which index in trainer.logger.destinations contains your desired logger.
# Alternatively, read `logger.data` directly from the instance created above.
logged_data = trainer.logger.destinations[0].data
- Parameters
log_level (str | LogLevel, optional) – The LogLevel (i.e., unit of resolution) at which to record. Defaults to BATCH, which records everything.
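To illustrate what "unit of resolution" means, here is a minimal pure-Python sketch. It assumes that log levels are ordered from coarse (FIT) to fine (BATCH) and that a logger silently drops data logged at a level finer than its own. `SketchLogLevel` and `FilteringLogger` are hypothetical names, not composer's API or its actual implementation.

```python
from enum import IntEnum
from typing import Any, Dict, List, Tuple


class SketchLogLevel(IntEnum):
    """Hypothetical stand-in for LogLevel, ordered from coarse to fine."""
    FIT = 1
    EPOCH = 2
    BATCH = 3


class FilteringLogger:
    """Hypothetical logger that keeps only data at or coarser than its level."""

    def __init__(self, log_level: SketchLogLevel = SketchLogLevel.BATCH) -> None:
        self.log_level = log_level
        self.records: List[Tuple[SketchLogLevel, Dict[str, Any]]] = []

    def log_data(self, log_level: SketchLogLevel, data: Dict[str, Any]) -> None:
        # Skip anything finer-grained than the configured resolution.
        if log_level > self.log_level:
            return
        self.records.append((log_level, data))


epoch_logger = FilteringLogger(SketchLogLevel.EPOCH)
epoch_logger.log_data(SketchLogLevel.BATCH, {"loss/train": 0.9})    # dropped
epoch_logger.log_data(SketchLogLevel.EPOCH, {"accuracy/val": 41.0})  # kept
print(len(epoch_logger.records))  # 1
```

Under this assumption, the default BATCH level is the finest, so nothing is filtered out, which matches "records everything".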
- data
Mapping of a logged key to a list of (Timestamp, LogLevel, logged value) tuples, one per logging call. This dictionary contains all logged data.
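The shape of this dictionary can be sketched in plain Python. This assumes each key holds a list of (timestamp, log level, value) tuples in the order they were logged, which is what allows a metric to be read back over time. `SketchTimestamp` is a hypothetical two-field stand-in for composer's Timestamp, the log level is shown as a plain string, and the loss values are made up for illustration.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass(frozen=True)
class SketchTimestamp:
    """Hypothetical stand-in for composer's Timestamp (two fields only)."""
    epoch: int
    batch: int


# Hypothetical contents after three batch-level logging calls for one key.
data: Dict[str, List[Tuple[SketchTimestamp, str, Any]]] = {
    "loss/train": [
        (SketchTimestamp(epoch=0, batch=1), "BATCH", 2.31),
        (SketchTimestamp(epoch=0, batch=2), "BATCH", 2.05),
        (SketchTimestamp(epoch=0, batch=3), "BATCH", 1.87),
    ],
}

# The most recent value for a key is the third element of its last tuple.
latest_timestamp, latest_level, latest_loss = data["loss/train"][-1]
print(latest_timestamp.batch, latest_loss)  # 3 1.87
```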
- get_timeseries(metric)
Returns logged data as a dict containing the values of a desired metric over time.
- Parameters
metric (str) – Metric of interest. Must be present in self.data.keys().
- Returns
timeseries (Dict[str, Any]) – Dictionary in which one key is metric, and the associated value is a list of values of that metric. The remaining keys are each a unit of time, and the associated values are each a list of values of that time unit for the corresponding index of the metric. For example:
>>> in_mem_logger.get_timeseries(metric="accuracy/val")
{"accuracy/val": [31.2, 45.6, 59.3, 64.7, ...], "epoch": [1, 2, 3, 4, ...], ..., "batch": [49, 98, 147, 196, ...], ...}
Example
import matplotlib.pyplot as plt
from composer.loggers import InMemoryLogger, LogLevel
from composer.core.time import Time, Timestamp

in_mem_logger = InMemoryLogger(LogLevel.BATCH)

# Populate the logger with data. `state` is assumed to be an existing
# composer State object (for example, trainer.state).
for b in range(0, 3):
    datapoint = b * 3
    in_mem_logger.log_data(state=state, log_level=LogLevel.BATCH, data={"accuracy/val": datapoint})

# Plot the metric against the batch at which each value was logged.
timeseries = in_mem_logger.get_timeseries("accuracy/val")
plt.plot(timeseries["batch"], timeseries["accuracy/val"])
plt.xlabel("Batch")
plt.ylabel("Validation Accuracy")
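The reshaping that get_timeseries describes (one list for the metric plus one parallel list per time unit) can be sketched without composer. This is a hypothetical reimplementation for illustration, not the library's code: it assumes data maps each key to a list of (timestamp, log level, value) tuples, and `SketchTimestamp` and `sketch_get_timeseries` are made-up names. The sample values reuse the accuracy, epoch, and batch numbers from the example output above.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Tuple


@dataclass(frozen=True)
class SketchTimestamp:
    """Hypothetical stand-in for composer's Timestamp, keeping two time units."""
    epoch: int
    batch: int


def sketch_get_timeseries(
    data: Dict[str, List[Tuple[SketchTimestamp, str, Any]]], metric: str
) -> Dict[str, List[Any]]:
    """Reshape one metric's logged tuples into parallel lists.

    Raises KeyError if `metric` was never logged.
    """
    unit_names = [f.name for f in fields(SketchTimestamp)]
    timeseries: Dict[str, List[Any]] = {metric: []}
    timeseries.update({name: [] for name in unit_names})
    # Position i in every list describes the same logging call.
    for timestamp, _level, value in data[metric]:
        timeseries[metric].append(value)
        for name in unit_names:
            timeseries[name].append(getattr(timestamp, name))
    return timeseries


logged = {
    "accuracy/val": [
        (SketchTimestamp(epoch=1, batch=49), "EPOCH", 31.2),
        (SketchTimestamp(epoch=2, batch=98), "EPOCH", 45.6),
    ]
}
ts = sketch_get_timeseries(logged, "accuracy/val")
print(ts["accuracy/val"], ts["batch"])  # [31.2, 45.6] [49, 98]
```

Keeping the lists parallel is what makes a call like plt.plot(timeseries["batch"], timeseries["accuracy/val"]) line up point by point.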