rex

Main logical entry point for ReX.

Functions

try_preprocess(args, model_shape, device)

Makes an attempt to preprocess input data as required for the model.

load_and_preprocess_data(model_shape, device, args)

Loads input data from a file path and preprocesses it.

validate_shape(data, model_shape)

predict_target(data, args, prediction_func)

Predicts the classification of the input data using the given prediction function.

calculate_responsibility(data, args, prediction_func[, custom_height, custom_width])

Calculates ResponsibilityMaps for the input data using the given args.

analyze(exp, data_mode)

Analyzes an Explanation.

_explanation(args, model_shape, prediction_func, device[, db, path])

Takes a CausalArgs object and model information and returns an Explanation.

get_prediction_func_from_args(args)

Takes a CausalArgs object and gets the prediction function and model shape.

explanation(args, device[, db])

Takes a CausalArgs object and returns an Explanation.

Module Contents

rex.try_preprocess(args, model_shape, device)

Makes an attempt to preprocess input data as required for the model.

Data preprocessing is based on the file extension and (possibly) a user-defined mode. Files with extensions in [".jpg", ".jpeg", ".png", ".tif", ".tiff"] are treated as images, ".npy" and ".mat" files are treated as NumPy arrays, and ".nii" files are treated as NIfTI files. For any other file extension, we create a Data object without preprocessing.

Parameters:
  • args (rex_xai.input.config.CausalArgs) – configuration values for ReX

  • model_shape (Tuple[int]) – shape of the input tensor of the model, as returned by get_prediction_func_from_args()

  • device (torch.device) – as returned by get_device()

Returns:

the processed input data

Return type:

Data
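The extension-based dispatch described for try_preprocess() can be pictured as a simple lookup. This is a minimal sketch, not ReX's implementation: the function name classify_input and the category labels are hypothetical, and matching extensions case-insensitively is an assumption.

```python
from pathlib import Path

# Extension groups as listed in the try_preprocess() description.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}
ARRAY_EXTS = {".npy", ".mat"}
NIFTI_EXTS = {".nii"}


def classify_input(path: str) -> str:
    """Return a hypothetical category label for a file path.

    Assumption: extensions are compared case-insensitively.
    """
    ext = Path(path).suffix.lower()
    if ext in IMAGE_EXTS:
        return "image"
    if ext in ARRAY_EXTS:
        return "array"
    if ext in NIFTI_EXTS:
        return "nifti"
    # Any other extension: the data would be wrapped without preprocessing.
    return "raw"


for name in ["cat.JPG", "scan.nii", "volume.npy", "notes.txt"]:
    print(name, "->", classify_input(name))
```

Note that Path.suffix only looks at the last extension, so a name such as "scan.nii.gz" falls through to "raw" in this sketch.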

rex.load_and_preprocess_data(model_shape, device, args)

Loads input data from a file path and preprocesses it.

Uses a custom preprocessing function if one is defined in args.script.preprocess; otherwise falls back to try_preprocess().

Parameters:
  • model_shape (Tuple[int]) – shape of the input tensor of the model, as returned by get_prediction_func_from_args()

  • device (torch.device) – as returned by get_device()

  • args (rex_xai.input.config.CausalArgs) – configuration values for ReX

Returns:

the processed input data

Return type:

Data
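The fallback rule in load_and_preprocess_data() amounts to "use the script's hook if present, else the default". The sketch below shows that selection only. The helper choose_preprocess and both stand-in preprocess functions are hypothetical, and giving the custom hook the same (args, model_shape, device) argument order as try_preprocess() is an assumption.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


def choose_preprocess(script: Optional[Any], default: Callable[..., Any]) -> Callable[..., Any]:
    """Pick the user's preprocess hook if one exists, else the default.

    Assumption: the hook lives at script.preprocess; getattr() with a
    default also copes with script being None.
    """
    custom = getattr(script, "preprocess", None)
    return custom if callable(custom) else default


# Hypothetical stand-ins that mirror try_preprocess()'s argument order.
def default_preprocess(args, model_shape, device):
    return ("default", args.path)


def my_preprocess(args, model_shape, device):
    return ("custom", args.path)


args_plain = SimpleNamespace(path="img.png", script=None)
args_hooked = SimpleNamespace(path="img.png", script=SimpleNamespace(preprocess=my_preprocess))

for args in (args_plain, args_hooked):
    fn = choose_preprocess(args.script, default_preprocess)
    print(fn(args, (1, 3, 224, 224), "cpu"))
```

Checking callable() rather than mere presence means a script that sets preprocess to None still falls back to the default.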

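rex.validate_shape(data, model_shape)

Parameters:
  • data (rex_xai.input.input_data.Data)

  • model_shape – shape of the input tensor of the model, as returned by get_prediction_func_from_args()

Return type:

rex_xai.input.input_data.Data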

rex.predict_target(data, args, prediction_func)

Predicts the classification of the input data using the given prediction function.

Uses prediction_func to identify the classification of the input data and returns it as the target classification for ReX.

Parameters:
  • data (rex_xai.input.input_data.Data) – processed input data object

  • args (rex_xai.input.config.CausalArgs) – configuration values for ReX

  • prediction_func – prediction function for the model

Returns:

the predicted target classification and confidence

Return type:

Prediction
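To make "target classification and confidence" concrete, the sketch below picks the highest-scoring class from a model's output. SimplePrediction is a hypothetical stand-in for ReX's Prediction type, and the assumption that the prediction function returns a flat sequence of per-class scores is for illustration only; ReX's own post-processing may differ.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass
class SimplePrediction:
    classification: int  # index of the predicted class
    confidence: float    # score the model assigned to that class


def predict_top1(data, prediction_func: Callable[[object], Sequence[float]]) -> SimplePrediction:
    scores = list(prediction_func(data))
    # Index of the highest-scoring class; max() keeps the first one on ties.
    best = max(range(len(scores)), key=scores.__getitem__)
    return SimplePrediction(classification=best, confidence=float(scores[best]))


def toy_model(_data):
    # Hypothetical model output: one probability per class.
    return [0.1, 0.7, 0.2]


print(predict_top1("any input", toy_model))  # SimplePrediction(classification=1, confidence=0.7)
```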

rex.calculate_responsibility(data, args, prediction_func, custom_height=None, custom_width=None)

Calculates ResponsibilityMaps for the input data using the given args.

Runs causal_explanation() for args.iters iterations, and returns a ResponsibilityMaps object and a dictionary containing some statistics about the calculation process. The ResponsibilityMaps object by default only includes the responsibility map that matches the classification of the input data. Set keep_all_maps to True to retain all maps.

Parameters:
  • data (rex_xai.input.input_data.Data) – processed input data object

  • args (rex_xai.input.config.CausalArgs) – configuration values for ReX

  • prediction_func – prediction function for the model

  • keep_all_maps – whether to retain all ResponsibilityMaps, or just the one that matches the target classification.

  • custom_height – optional, defaults to None

  • custom_width – optional, defaults to None

Returns:

tuple containing

  • ResponsibilityMaps: ResponsibilityMaps for the given data, prediction function, and args.

  • dict: statistics for the call of this function that generated the ResponsibilityMaps object

Return type:

tuple[rex_xai.responsibility.resp_maps.ResponsibilityMaps, dict]
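The "run for args.iters iterations, then keep only the target map" behaviour of calculate_responsibility() can be sketched with plain lists. Everything here is hypothetical: one_iteration stands in for a single causal_explanation() pass, summing maps across iterations is an assumption about how results combine, and the statistics key "iterations" is invented for illustration.

```python
from typing import Callable, Dict, List, Tuple

Map = List[float]


def accumulate_maps(
    iters: int,
    one_iteration: Callable[[int], Dict[int, Map]],
    target: int,
    keep_all_maps: bool = False,
) -> Tuple[Dict[int, Map], dict]:
    totals: Dict[int, Map] = {}
    for i in range(iters):
        # Each pass yields one map per classification it produced (assumed).
        for cls, m in one_iteration(i).items():
            if cls not in totals:
                totals[cls] = [0.0] * len(m)
            totals[cls] = [a + b for a, b in zip(totals[cls], m)]
    if not keep_all_maps:
        # Default: retain only the map for the target classification.
        totals = {cls: m for cls, m in totals.items() if cls == target}
    stats = {"iterations": iters}
    return totals, stats


def fake_pass(i: int) -> Dict[int, Map]:
    return {3: [1.0, 0.0], 7: [0.0, float(i)]}


maps, stats = accumulate_maps(2, fake_pass, target=3)
print(maps, stats)  # {3: [2.0, 0.0]} {'iterations': 2}
```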

rex.analyze(exp, data_mode)

Analyzes an Explanation.

Computes the area ratio, entropy difference, and insertion and deletion curves for an Explanation object, prints them, and returns them.

Parameters:
  • exp (rex_xai.explanation.explanation.Explanation) – Explanation object as returned by _explanation()

  • data_mode (str | None) – Mode of the input data. If data_mode is “RGB”, the entropy of the responsibility map is calculated; if data_mode is “spectral”, the spectral entropy is calculated.

Returns:

dictionary containing

  • area (float)

  • entropy (float)

  • insertion_curve (float)

  • deletion_curve (float)

Return type:

Dict[str, float]
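For orientation, the sketch below computes two generic quantities of the kind analyze() reports: the Shannon entropy (in bits) of a responsibility map normalised to sum to 1, and the area under an insertion or deletion curve using the trapezoidal rule over evenly spaced points on [0, 1]. These are textbook definitions chosen for illustration, not ReX's exact formulas; the area ratio is omitted because its definition is not given here, and the helper names are hypothetical.

```python
import math
from typing import Dict, Sequence


def shannon_entropy(resp_map: Sequence[float]) -> float:
    """Entropy in bits of a non-negative map normalised to a distribution."""
    total = sum(resp_map)
    if total == 0:
        return 0.0
    probs = [v / total for v in resp_map if v > 0]
    return -sum(p * math.log2(p) for p in probs)


def trapezoid_auc(ys: Sequence[float]) -> float:
    """Area under a curve sampled at evenly spaced points on [0, 1]."""
    if len(ys) < 2:
        return 0.0
    step = 1.0 / (len(ys) - 1)
    return sum((a + b) / 2 * step for a, b in zip(ys, ys[1:]))


def summarize(resp_map, insertion, deletion) -> Dict[str, float]:
    # Keys mirror some of the names analyze() documents.
    return {
        "entropy": shannon_entropy(resp_map),
        "insertion_curve": trapezoid_auc(insertion),
        "deletion_curve": trapezoid_auc(deletion),
    }


print(summarize([1, 1, 1, 1], [0.0, 0.5, 1.0], [1.0, 0.5, 0.0]))
```

A uniform map has maximal entropy (2 bits for four equal cells), while a map concentrated on a single cell has entropy 0.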

rex._explanation(args, model_shape, prediction_func, device, db=None, path=None)

Takes a CausalArgs object and model information and returns an Explanation.

Takes a CausalArgs object, model shape and prediction function and returns an Explanation. Depending on the input args, optionally produces output plots, analyzes the output explanation, and/or writes results to a database.

Parameters:
  • args (rex_xai.input.config.CausalArgs) – configuration values for ReX

  • model_shape (Tuple[int]) – shape of the input tensor of the model, as returned by get_prediction_func_from_args()

  • prediction_func (Callable) – as returned by get_prediction_func_from_args()

  • device (torch.device) – as returned by get_device()

  • db (sqlalchemy.orm.Session | None) – None or as returned by initialise_rex_db()

  • path (str | None)

Returns:

An Explanation object containing the causal responsibility explanation calculated using the given args.

Return type:

Explanation

rex.get_prediction_func_from_args(args)

Takes a CausalArgs object and gets the prediction function and model shape.

If args.script specifies a prediction function and model shape, returns these. Otherwise gets the prediction function and model shape from the provided model file.

Parameters:

args (rex_xai.input.config.CausalArgs) – configuration values for ReX

Returns:

tuple containing

  • prediction_func

  • model_shape

Raises:

RuntimeError – if an ONNX inference instance cannot be created from the provided model file.
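The precedence rule in get_prediction_func_from_args() (script-supplied function and shape first, otherwise derive both from the model file) and its RuntimeError behaviour can be sketched as below. The attribute names prediction_function and model_shape on script, the helper resolve_prediction_func, and the load_model callable are hypothetical stand-ins; in ReX the fallback creates an ONNX inference instance from the model file.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional, Tuple


def resolve_prediction_func(
    script: Optional[Any],
    model_path: str,
    load_model: Callable[[str], Tuple[Callable, tuple]],
) -> Tuple[Callable, tuple]:
    func = getattr(script, "prediction_function", None)
    shape = getattr(script, "model_shape", None)
    if func is not None and shape is not None:
        # The script provides both, so the model file is never consulted.
        return func, shape
    try:
        return load_model(model_path)
    except Exception as exc:
        # Surface any loading failure as a RuntimeError, chaining the cause.
        raise RuntimeError(f"could not create an inference instance for {model_path!r}") from exc


def broken_loader(path):
    raise OSError("file not found")


script = SimpleNamespace(prediction_function=lambda x: [1.0], model_shape=(1, 3, 224, 224))
func, shape = resolve_prediction_func(script, "model.onnx", broken_loader)
print(shape)  # (1, 3, 224, 224); broken_loader was not called

try:
    resolve_prediction_func(None, "missing.onnx", broken_loader)
except RuntimeError as err:
    print("RuntimeError:", err)
```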

rex.explanation(args, device, db=None)

Takes a CausalArgs object and returns an Explanation.

Takes a CausalArgs object and returns either an Explanation, or a list of Explanations if the input args.path is a directory rather than a path to a single file.

Parameters:
  • args (rex_xai.input.config.CausalArgs) – configuration values for ReX

  • device (torch.device) – as returned by get_device()

  • db (sqlalchemy.orm.Session | None) – None or as returned by initialise_rex_db()

Returns:

An Explanation object (or a list of Explanation objects if args.path is a directory) containing the causal responsibility explanation calculated using the given args.

Return type:

Explanation | list[Explanation]
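The file-versus-directory behaviour of explanation() can be sketched with the standard library. The helper explain_path and the explain_one callable are hypothetical, and processing only the regular files directly inside the directory, in sorted order, is an assumption rather than documented ReX behaviour.

```python
import os
import tempfile
from typing import Callable, List, TypeVar, Union

T = TypeVar("T")


def explain_path(path: str, explain_one: Callable[[str], T]) -> Union[T, List[T]]:
    if os.path.isdir(path):
        # Directory: explain every regular file directly inside it (not recursive).
        files = sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if os.path.isfile(os.path.join(path, name))
        )
        return [explain_one(f) for f in files]
    # Single file: return one result rather than a list.
    return explain_one(path)


with tempfile.TemporaryDirectory() as tmp:
    for name in ("b.png", "a.png"):
        open(os.path.join(tmp, name), "w").close()
    print(explain_path(tmp, os.path.basename))                         # ['a.png', 'b.png']
    print(explain_path(os.path.join(tmp, "a.png"), os.path.basename))  # a.png
```

Callers therefore need to check whether they received a single result or a list, matching the two cases described for explanation().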