lm_polygraph.stat_calculators.embeddings module

class lm_polygraph.stat_calculators.embeddings.EmbeddingsCalculator(hidden_layer: int = -1)

Bases: StatCalculator

static meta_info() → Tuple[List[str], List[str]]

Returns two lists of names: the statistics this calculator computes and the statistics it depends on.
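
The return shape of meta_info() can be illustrated with a minimal sketch. ToyCalculator and both statistic names below are hypothetical and are not part of lm_polygraph; the ordering (statistics first, dependencies second) is assumed from the description above.

```python
from typing import List, Tuple


class ToyCalculator:
    """Hypothetical stand-in for a calculator; not part of lm_polygraph."""

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        # First list: names of the statistics this calculator produces.
        # Second list: names of the statistics it needs from elsewhere.
        # Both names are invented for illustration only.
        return ["toy_embeddings"], ["toy_input_tokens"]


# The method is static, so no instance is required to query it.
stats, dependencies = ToyCalculator.meta_info()
```

Because the method is static, a caller could in principle inspect what a calculator provides and requires before constructing it.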

lm_polygraph.stat_calculators.embeddings.aggregate(x, aggregation_method, axis)

lm_polygraph.stat_calculators.embeddings.get_embeddings_from_output(output, batch, model_type, hidden_state: List[str] = ['encoder', 'decoder'], ignore_padding: bool = True, use_averaging: bool = True, all_layers: bool = False, aggregation_method: str = 'mean', level: str = 'sequence', hidden_layer: int = -1)
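
This page does not describe the parameters of get_embeddings_from_output. As a rough illustration of what the defaults ignore_padding=True, aggregation_method='mean' and level='sequence' might mean, the sketch below mean-pools per-token vectors into one sequence vector and skips padded positions. The helper pool_sequence and its attention_mask argument are hypothetical. It uses plain Python lists rather than model outputs and is not the library's implementation.

```python
from typing import List


def pool_sequence(
    token_vectors: List[List[float]],
    attention_mask: List[int],
    ignore_padding: bool = True,
) -> List[float]:
    """Mean-pool token vectors into one sequence vector (hypothetical helper)."""
    # Keep only real tokens when ignore_padding is set; otherwise keep all.
    kept = [
        vec
        for vec, keep in zip(token_vectors, attention_mask)
        if keep or not ignore_padding
    ]
    if not kept:
        raise ValueError("no tokens left to pool")
    dim = len(kept[0])
    # Average each hidden dimension over the kept tokens.
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


hidden = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # last row is padding
mask = [1, 1, 0]
masked = pool_sequence(hidden, mask)  # averages the two real tokens
unmasked = pool_sequence(hidden, mask, ignore_padding=False)  # averages all three
```

Including the padded row pulls the average toward zero, which is the usual reason for excluding padding when pooling.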