refactor(amu): refactor enablement and context switching

This change represents a general refactoring to clean up old code that
has been adapted to account for changes required to enable dynamic
auxiliary counters.

Change-Id: Ia85e0518f3f65c765f07b34b67744fc869b9070d
Signed-off-by: Chris Kay <chris.kay@arm.com>
diff --git a/lib/extensions/amu/aarch32/amu.c b/lib/extensions/amu/aarch32/amu.c
index 8948798..e92b9f1 100644
--- a/lib/extensions/amu/aarch32/amu.c
+++ b/lib/extensions/amu/aarch32/amu.c
@@ -8,16 +8,35 @@
 #include <cdefs.h>
 #include <stdbool.h>
 
+#include "../amu_private.h"
 #include <arch.h>
 #include <arch_helpers.h>
-
 #include <lib/el3_runtime/pubsub_events.h>
 #include <lib/extensions/amu.h>
-#include <lib/extensions/amu_private.h>
 
 #include <plat/common/platform.h>
 
+struct amu_ctx {
+	uint64_t group0_cnts[AMU_GROUP0_MAX_COUNTERS];
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	uint64_t group1_cnts[AMU_GROUP1_MAX_COUNTERS];
+#endif
+
+	uint16_t group0_enable;
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	uint16_t group1_enable;
+#endif
+};
+
+static struct amu_ctx amu_ctxs_[PLATFORM_CORE_COUNT];
+
+CASSERT((sizeof(amu_ctxs_[0].group0_enable) * CHAR_BIT) >= AMU_GROUP0_MAX_COUNTERS,
+	amu_ctx_group0_enable_cannot_represent_all_group0_counters);
+
-static struct amu_ctx amu_ctxs[PLATFORM_CORE_COUNT];
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+CASSERT((sizeof(amu_ctxs_[0].group1_enable) * CHAR_BIT) >= AMU_GROUP1_MAX_COUNTERS,
+	amu_ctx_group1_enable_cannot_represent_all_group1_counters);
+#endif
 
 static inline __unused uint32_t read_id_pfr0_amu(void)
 {
@@ -109,53 +128,72 @@
 	write_amcntenclr1(value);
 }
 
-static bool amu_supported(void)
+static __unused bool amu_supported(void)
 {
 	return read_id_pfr0_amu() >= ID_PFR0_AMU_V1;
 }
 
-static bool amu_v1p1_supported(void)
-{
-	return read_id_pfr0_amu() >= ID_PFR0_AMU_V1P1;
-}
-
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-static bool amu_group1_supported(void)
+static __unused bool amu_group1_supported(void)
 {
 	return read_amcfgr_ncg() > 0U;
 }
 #endif
 
 /*
- * Enable counters. This function is meant to be invoked
- * by the context management library before exiting from EL3.
+ * Enable counters. This function is meant to be invoked by the context
+ * management library before exiting from EL3.
  */
 void amu_enable(bool el2_unused)
 {
-	if (!amu_supported()) {
+	uint32_t id_pfr0_amu;		/* AMU version */
+
+	uint32_t amcfgr_ncg;		/* Number of counter groups */
+	uint32_t amcgcr_cg0nc;		/* Number of group 0 counters */
+
+	uint32_t amcntenset0_px = 0x0;	/* Group 0 enable mask */
+	uint32_t amcntenset1_px = 0x0;	/* Group 1 enable mask */
+
+	id_pfr0_amu = read_id_pfr0_amu();
+	if (id_pfr0_amu == ID_PFR0_AMU_NOT_SUPPORTED) {
+		/*
+		 * If the AMU is unsupported, nothing needs to be done.
+		 */
+
 		return;
 	}
 
 	if (el2_unused) {
 		/*
-		 * Non-secure access from EL0 or EL1 to the Activity Monitor
-		 * registers do not trap to EL2.
+		 * HCPTR.TAM: Set to zero so any accesses to the Activity
+		 * Monitor registers do not trap to EL2.
 		 */
 		write_hcptr_tam(0U);
 	}
 
-	/* Enable group 0 counters */
-	write_amcntenset0_px((UINT32_C(1) << read_amcgcr_cg0nc()) - 1U);
+	/*
+	 * Retrieve the number of architected counters. All of these counters
+	 * are enabled by default.
+	 */
 
-#if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		/* Enable group 1 counters */
-		write_amcntenset1_px(AMU_GROUP1_COUNTERS_MASK);
+	amcgcr_cg0nc = read_amcgcr_cg0nc();
+	assert(amcgcr_cg0nc <= AMU_AMCGCR_CG0NC_MAX);
+
+	amcntenset0_px = (UINT32_C(1) << (amcgcr_cg0nc)) - 1U;
+
+	/*
+	 * Enable the requested counters.
+	 */
+
+	write_amcntenset0_px(amcntenset0_px);
+
+	amcfgr_ncg = read_amcfgr_ncg();
+	if (amcfgr_ncg > 0U) {
+		write_amcntenset1_px(amcntenset1_px);
 	}
-#endif
 
 	/* Initialize FEAT_AMUv1p1 features if present. */
-	if (!amu_v1p1_supported()) {
+	if (id_pfr0_amu < ID_PFR0_AMU_V1P1) {
 		return;
 	}
 
@@ -218,49 +256,61 @@
 
 static void *amu_context_save(const void *arg)
 {
-	struct amu_ctx *ctx = &amu_ctxs[plat_my_core_pos()];
-	unsigned int i;
+	uint32_t i;
 
-	if (!amu_supported()) {
-		return (void *)-1;
-	}
+	unsigned int core_pos;
+	struct amu_ctx *ctx;
 
-	/* Assert that group 0/1 counter configuration is what we expect */
-	assert(read_amcntenset0_px() ==
-		((UINT32_C(1) << read_amcgcr_cg0nc()) - 1U));
+	uint32_t id_pfr0_amu;	/* AMU version */
+	uint32_t amcgcr_cg0nc;	/* Number of group 0 counters */
 
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		assert(read_amcntenset1_px() == AMU_GROUP1_COUNTERS_MASK);
+	uint32_t amcfgr_ncg;	/* Number of counter groups */
+	uint32_t amcgcr_cg1nc;	/* Number of group 1 counters */
+#endif
+
+	id_pfr0_amu = read_id_pfr0_amu();
+	if (id_pfr0_amu == ID_PFR0_AMU_NOT_SUPPORTED) {
+		return (void *)0;
 	}
+
+	core_pos = plat_my_core_pos();
+	ctx = &amu_ctxs_[core_pos];
+
+	amcgcr_cg0nc = read_amcgcr_cg0nc();
+
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	amcfgr_ncg = read_amcfgr_ncg();
+	amcgcr_cg1nc = (amcfgr_ncg > 0U) ? read_amcgcr_cg1nc() : 0U;
 #endif
+
 	/*
-	 * Disable group 0/1 counters to avoid other observers like SCP sampling
-	 * counter values from the future via the memory mapped view.
+	 * Disable all AMU counters.
 	 */
-	write_amcntenclr0_px((UINT32_C(1) << read_amcgcr_cg0nc()) - 1U);
+
+	ctx->group0_enable = read_amcntenset0_px();
+	write_amcntenclr0_px(ctx->group0_enable);
 
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		write_amcntenclr1_px(AMU_GROUP1_COUNTERS_MASK);
+	if (amcfgr_ncg > 0U) {
+		ctx->group1_enable = read_amcntenset1_px();
+		write_amcntenclr1_px(ctx->group1_enable);
 	}
 #endif
 
-	isb();
+	/*
+	 * Save the counters to the local context.
+	 */
 
-	/* Save all group 0 counters */
-	for (i = 0U; i < read_amcgcr_cg0nc(); i++) {
+	isb(); /* Ensure counters have been stopped */
+
+	for (i = 0U; i < amcgcr_cg0nc; i++) {
 		ctx->group0_cnts[i] = amu_group0_cnt_read(i);
 	}
 
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		/* Save group 1 counters */
-		for (i = 0U; i < read_amcgcr_cg1nc(); i++) {
-			if ((AMU_GROUP1_COUNTERS_MASK & (1U << i)) != 0U) {
-				ctx->group1_cnts[i] = amu_group1_cnt_read(i);
-			}
-		}
+	for (i = 0U; i < amcgcr_cg1nc; i++) {
+		ctx->group1_cnts[i] = amu_group1_cnt_read(i);
 	}
 #endif
 
@@ -269,41 +319,69 @@
 
 static void *amu_context_restore(const void *arg)
 {
-	struct amu_ctx *ctx = &amu_ctxs[plat_my_core_pos()];
-	unsigned int i;
+	uint32_t i;
+
+	unsigned int core_pos;
+	struct amu_ctx *ctx;
 
-	if (!amu_supported()) {
-		return (void *)-1;
+	uint32_t id_pfr0_amu;	/* AMU version */
+
+	uint32_t amcfgr_ncg;	/* Number of counter groups */
+	uint32_t amcgcr_cg0nc;	/* Number of group 0 counters */
+
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	uint32_t amcgcr_cg1nc;	/* Number of group 1 counters */
+#endif
+
+	id_pfr0_amu = read_id_pfr0_amu();
+	if (id_pfr0_amu == ID_PFR0_AMU_NOT_SUPPORTED) {
+		return (void *)0;
 	}
 
-	/* Counters were disabled in `amu_context_save()` */
-	assert(read_amcntenset0_px() == 0U);
+	core_pos = plat_my_core_pos();
+	ctx = &amu_ctxs_[core_pos];
+
+	amcfgr_ncg = read_amcfgr_ncg();
+	amcgcr_cg0nc = read_amcgcr_cg0nc();
 
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
+	amcgcr_cg1nc = (amcfgr_ncg > 0U) ? read_amcgcr_cg1nc() : 0U;
+#endif
+
+	/*
+	 * Sanity check that all counters were disabled when the context was
+	 * previously saved.
+	 */
+
+	assert(read_amcntenset0_px() == 0U);
+
+	if (amcfgr_ncg > 0U) {
 		assert(read_amcntenset1_px() == 0U);
 	}
-#endif
+
+	/*
+	 * Restore the counter values from the local context.
+	 */
 
-	/* Restore all group 0 counters */
-	for (i = 0U; i < read_amcgcr_cg0nc(); i++) {
+	for (i = 0U; i < amcgcr_cg0nc; i++) {
 		amu_group0_cnt_write(i, ctx->group0_cnts[i]);
 	}
 
-	/* Restore group 0 counter configuration */
-	write_amcntenset0_px((UINT32_C(1) << read_amcgcr_cg0nc()) - 1U);
-
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		/* Restore group 1 counters */
-		for (i = 0U; i < read_amcgcr_cg1nc(); i++) {
-			if ((AMU_GROUP1_COUNTERS_MASK & (1U << i)) != 0U) {
-				amu_group1_cnt_write(i, ctx->group1_cnts[i]);
-			}
-		}
+	for (i = 0U; i < amcgcr_cg1nc; i++) {
+		amu_group1_cnt_write(i, ctx->group1_cnts[i]);
+	}
+#endif
+
+	/*
+	 * Re-enable counters that were disabled during context save.
+	 */
 
-		/* Restore group 1 counter configuration */
-		write_amcntenset1_px(AMU_GROUP1_COUNTERS_MASK);
+	write_amcntenset0_px(ctx->group0_enable);
+
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	if (amcfgr_ncg > 0U) {
+		write_amcntenset1_px(ctx->group1_enable);
 	}
 #endif
 
diff --git a/lib/extensions/amu/aarch64/amu.c b/lib/extensions/amu/aarch64/amu.c
index 6bed6c3..58094ae 100644
--- a/lib/extensions/amu/aarch64/amu.c
+++ b/lib/extensions/amu/aarch64/amu.c
@@ -8,17 +8,42 @@
 #include <cdefs.h>
 #include <stdbool.h>
 
+#include "../amu_private.h"
 #include <arch.h>
 #include <arch_features.h>
 #include <arch_helpers.h>
-
 #include <lib/el3_runtime/pubsub_events.h>
 #include <lib/extensions/amu.h>
-#include <lib/extensions/amu_private.h>
 
 #include <plat/common/platform.h>
 
-static struct amu_ctx amu_ctxs[PLATFORM_CORE_COUNT];
+struct amu_ctx {
+	uint64_t group0_cnts[AMU_GROUP0_MAX_COUNTERS];
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	uint64_t group1_cnts[AMU_GROUP1_MAX_COUNTERS];
+#endif
+
+	/* Architected event counter 1 does not have an offset register */
+	uint64_t group0_voffsets[AMU_GROUP0_MAX_COUNTERS - 1U];
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	uint64_t group1_voffsets[AMU_GROUP1_MAX_COUNTERS];
+#endif
+
+	uint16_t group0_enable;
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	uint16_t group1_enable;
+#endif
+};
+
+static struct amu_ctx amu_ctxs_[PLATFORM_CORE_COUNT];
+
+CASSERT((sizeof(amu_ctxs_[0].group0_enable) * CHAR_BIT) >= AMU_GROUP0_MAX_COUNTERS,
+	amu_ctx_group0_enable_cannot_represent_all_group0_counters);
+
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+CASSERT((sizeof(amu_ctxs_[0].group1_enable) * CHAR_BIT) >= AMU_GROUP1_MAX_COUNTERS,
+	amu_ctx_group1_enable_cannot_represent_all_group1_counters);
+#endif
 
 static inline __unused uint64_t read_id_aa64pfr0_el1_amu(void)
 {
@@ -66,7 +91,7 @@
 		AMCFGR_EL0_NCG_MASK;
 }
 
-static inline uint64_t read_amcgcr_el0_cg0nc(void)
+static inline __unused uint64_t read_amcgcr_el0_cg0nc(void)
 {
 	return (read_amcgcr_el0() >> AMCGCR_EL0_CG0NC_SHIFT) &
 		AMCGCR_EL0_CG0NC_MASK;
@@ -136,37 +161,50 @@
 	write_amcntenclr1_el0(value);
 }
 
-static bool amu_supported(void)
+static __unused bool amu_supported(void)
 {
 	return read_id_aa64pfr0_el1_amu() >= ID_AA64PFR0_AMU_V1;
 }
 
-static bool amu_v1p1_supported(void)
+static __unused bool amu_v1p1_supported(void)
 {
 	return read_id_aa64pfr0_el1_amu() >= ID_AA64PFR0_AMU_V1P1;
 }
 
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-static bool amu_group1_supported(void)
+static __unused bool amu_group1_supported(void)
 {
 	return read_amcfgr_el0_ncg() > 0U;
 }
 #endif
 
 /*
- * Enable counters. This function is meant to be invoked
- * by the context management library before exiting from EL3.
+ * Enable counters. This function is meant to be invoked by the context
+ * management library before exiting from EL3.
  */
 void amu_enable(bool el2_unused, cpu_context_t *ctx)
 {
-	if (!amu_supported()) {
+	uint64_t id_aa64pfr0_el1_amu;		/* AMU version */
+
+	uint64_t amcfgr_el0_ncg;		/* Number of counter groups */
+	uint64_t amcgcr_el0_cg0nc;		/* Number of group 0 counters */
+
+	uint64_t amcntenset0_el0_px = 0x0;	/* Group 0 enable mask */
+	uint64_t amcntenset1_el0_px = 0x0;	/* Group 1 enable mask */
+
+	id_aa64pfr0_el1_amu = read_id_aa64pfr0_el1_amu();
+	if (id_aa64pfr0_el1_amu == ID_AA64PFR0_AMU_NOT_SUPPORTED) {
+		/*
+		 * If the AMU is unsupported, nothing needs to be done.
+		 */
+
 		return;
 	}
 
 	if (el2_unused) {
 		/*
-		 * CPTR_EL2.TAM: Set to zero so any accesses to
-		 * the Activity Monitor registers do not trap to EL2.
+		 * CPTR_EL2.TAM: Set to zero so any accesses to the Activity
+		 * Monitor registers do not trap to EL2.
 		 */
 		write_cptr_el2_tam(0U);
 	}
@@ -178,18 +216,29 @@
 	 */
 	write_cptr_el3_tam(ctx, 0U);
 
+	/*
+	 * Retrieve the number of architected counters. All of these counters
+	 * are enabled by default.
+	 */
+
-	/* Enable group 0 counters */
-	write_amcntenset0_el0_px((UINT64_C(1) << read_amcgcr_el0_cg0nc()) - 1U);
+	amcgcr_el0_cg0nc = read_amcgcr_el0_cg0nc();
+	assert(amcgcr_el0_cg0nc <= AMU_AMCGCR_CG0NC_MAX);
 
-#if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		/* Enable group 1 counters */
-		write_amcntenset1_el0_px(AMU_GROUP1_COUNTERS_MASK);
+	amcntenset0_el0_px = (UINT64_C(1) << (amcgcr_el0_cg0nc)) - 1U;
+
+	/*
+	 * Enable the requested counters.
+	 */
+
+	write_amcntenset0_el0_px(amcntenset0_el0_px);
+
+	amcfgr_el0_ncg = read_amcfgr_el0_ncg();
+	if (amcfgr_el0_ncg > 0U) {
+		write_amcntenset1_el0_px(amcntenset1_el0_px);
 	}
-#endif
 
 	/* Initialize FEAT_AMUv1p1 features if present. */
-	if (!amu_v1p1_supported()) {
+	if (id_aa64pfr0_el1_amu < ID_AA64PFR0_AMU_V1P1) {
 		return;
 	}
 
@@ -233,6 +282,31 @@
 }
 
 /*
+ * Unlike with auxiliary counters, we cannot detect at runtime whether an
+ * architected counter supports a virtual offset. These are instead fixed
+ * according to FEAT_AMUv1p1, but this switch will need to be updated if later
+ * revisions of FEAT_AMU add additional architected counters.
+ */
+static bool amu_group0_voffset_supported(uint64_t idx)
+{
+	switch (idx) {
+	case 0U:
+	case 2U:
+	case 3U:
+		return true;
+
+	case 1U:
+		return false;
+
+	default:
+		ERROR("AMU: can't set up virtual offset for unknown "
+		      "architected counter %llu!\n", idx);
+
+		panic();
+	}
+}
+
+/*
  * Read the group 0 offset register for a given index. Index must be 0, 2,
  * or 3, the register for 1 does not exist.
  *
@@ -319,135 +393,192 @@
 
 static void *amu_context_save(const void *arg)
 {
-	struct amu_ctx *ctx = &amu_ctxs[plat_my_core_pos()];
-	unsigned int i;
+	uint64_t i, j;
 
-	if (!amu_supported()) {
-		return (void *)-1;
-	}
+	unsigned int core_pos;
+	struct amu_ctx *ctx;
 
-	/* Assert that group 0/1 counter configuration is what we expect */
-	assert(read_amcntenset0_el0_px() ==
-		((UINT64_C(1) << read_amcgcr_el0_cg0nc()) - 1U));
+	uint64_t id_aa64pfr0_el1_amu;	/* AMU version */
+	uint64_t hcr_el2_amvoffen;	/* AMU virtual offsets enabled */
+	uint64_t amcgcr_el0_cg0nc;	/* Number of group 0 counters */
 
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		assert(read_amcntenset1_el0_px() == AMU_GROUP1_COUNTERS_MASK);
+	uint64_t amcg1idr_el0_voff;	/* Auxiliary counters with virtual offsets */
+	uint64_t amcfgr_el0_ncg;	/* Number of counter groups */
+	uint64_t amcgcr_el0_cg1nc;	/* Number of group 1 counters */
+#endif
+
+	id_aa64pfr0_el1_amu = read_id_aa64pfr0_el1_amu();
+	if (id_aa64pfr0_el1_amu == ID_AA64PFR0_AMU_NOT_SUPPORTED) {
+		return (void *)0;
 	}
+
+	core_pos = plat_my_core_pos();
+	ctx = &amu_ctxs_[core_pos];
+
+	amcgcr_el0_cg0nc = read_amcgcr_el0_cg0nc();
+	hcr_el2_amvoffen = (id_aa64pfr0_el1_amu >= ID_AA64PFR0_AMU_V1P1) ?
+		read_hcr_el2_amvoffen() : 0U;
+
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	amcfgr_el0_ncg = read_amcfgr_el0_ncg();
+	amcgcr_el0_cg1nc = (amcfgr_el0_ncg > 0U) ? read_amcgcr_el0_cg1nc() : 0U;
+	amcg1idr_el0_voff = (hcr_el2_amvoffen != 0U) ? read_amcg1idr_el0_voff() : 0U;
 #endif
 
 	/*
-	 * Disable group 0/1 counters to avoid other observers like SCP sampling
-	 * counter values from the future via the memory mapped view.
+	 * Disable all AMU counters.
 	 */
-	write_amcntenclr0_el0_px((UINT64_C(1) << read_amcgcr_el0_cg0nc()) - 1U);
+
+	ctx->group0_enable = read_amcntenset0_el0_px();
+	write_amcntenclr0_el0_px(ctx->group0_enable);
 
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		write_amcntenclr1_el0_px(AMU_GROUP1_COUNTERS_MASK);
+	if (amcfgr_el0_ncg > 0U) {
+		ctx->group1_enable = read_amcntenset1_el0_px();
+		write_amcntenclr1_el0_px(ctx->group1_enable);
 	}
 #endif
 
-	isb();
+	/*
+	 * Save the counters to the local context.
+	 */
 
-	/* Save all group 0 counters */
-	for (i = 0U; i < read_amcgcr_el0_cg0nc(); i++) {
+	isb(); /* Ensure counters have been stopped */
+
+	for (i = 0U; i < amcgcr_el0_cg0nc; i++) {
 		ctx->group0_cnts[i] = amu_group0_cnt_read(i);
 	}
 
-	/* Save group 0 virtual offsets if supported and enabled. */
-	if (amu_v1p1_supported() && (read_hcr_el2_amvoffen() != 0U)) {
-		/* Not using a loop because count is fixed and index 1 DNE. */
-		ctx->group0_voffsets[0U] = amu_group0_voffset_read(0U);
-		ctx->group0_voffsets[1U] = amu_group0_voffset_read(2U);
-		ctx->group0_voffsets[2U] = amu_group0_voffset_read(3U);
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	for (i = 0U; i < amcgcr_el0_cg1nc; i++) {
+		ctx->group1_cnts[i] = amu_group1_cnt_read(i);
 	}
+#endif
 
-#if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		/* Save group 1 counters */
-		for (i = 0U; i < read_amcgcr_el0_cg1nc(); i++) {
-			if ((AMU_GROUP1_COUNTERS_MASK & (1UL << i)) != 0U) {
-				ctx->group1_cnts[i] = amu_group1_cnt_read(i);
+	/*
+	 * Save virtual offsets for counters that offer them.
+	 */
+
+	if (hcr_el2_amvoffen != 0U) {
+		for (i = 0U, j = 0U; i < amcgcr_el0_cg0nc; i++) {
+			if (!amu_group0_voffset_supported(i)) {
+				continue; /* No virtual offset */
 			}
-		}
 
-		/* Save group 1 virtual offsets if supported and enabled. */
-		if (amu_v1p1_supported() && (read_hcr_el2_amvoffen() != 0U)) {
-			uint64_t amcg1idr = read_amcg1idr_el0_voff() &
-				AMU_GROUP1_COUNTERS_MASK;
+			ctx->group0_voffsets[j++] = amu_group0_voffset_read(i);
+		}
 
-			for (i = 0U; i < read_amcgcr_el0_cg1nc(); i++) {
-				if (((amcg1idr >> i) & 1ULL) != 0ULL) {
-					ctx->group1_voffsets[i] =
-						amu_group1_voffset_read(i);
-				}
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+		for (i = 0U, j = 0U; i < amcgcr_el0_cg1nc; i++) {
+			if (((amcg1idr_el0_voff >> i) & 1U) == 0U) {
+				continue; /* No virtual offset */
 			}
+
+			ctx->group1_voffsets[j++] = amu_group1_voffset_read(i);
 		}
-	}
 #endif
+	}
 
 	return (void *)0;
 }
 
 static void *amu_context_restore(const void *arg)
 {
-	struct amu_ctx *ctx = &amu_ctxs[plat_my_core_pos()];
-	unsigned int i;
+	uint64_t i, j;
 
-	if (!amu_supported()) {
-		return (void *)-1;
+	unsigned int core_pos;
+	struct amu_ctx *ctx;
+
+	uint64_t id_aa64pfr0_el1_amu;	/* AMU version */
+
+	uint64_t hcr_el2_amvoffen;	/* AMU virtual offsets enabled */
+
+	uint64_t amcfgr_el0_ncg;	/* Number of counter groups */
+	uint64_t amcgcr_el0_cg0nc;	/* Number of group 0 counters */
+
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	uint64_t amcgcr_el0_cg1nc;	/* Number of group 1 counters */
+	uint64_t amcg1idr_el0_voff;	/* Auxiliary counters with virtual offsets */
+#endif
+
+	id_aa64pfr0_el1_amu = read_id_aa64pfr0_el1_amu();
+	if (id_aa64pfr0_el1_amu == ID_AA64PFR0_AMU_NOT_SUPPORTED) {
+		return (void *)0;
 	}
 
-	/* Counters were disabled in `amu_context_save()` */
-	assert(read_amcntenset0_el0_px() == 0U);
+	core_pos = plat_my_core_pos();
+	ctx = &amu_ctxs_[core_pos];
+
+	amcfgr_el0_ncg = read_amcfgr_el0_ncg();
+	amcgcr_el0_cg0nc = read_amcgcr_el0_cg0nc();
+
+	hcr_el2_amvoffen = (id_aa64pfr0_el1_amu >= ID_AA64PFR0_AMU_V1P1) ?
+		read_hcr_el2_amvoffen() : 0U;
 
 #if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
+	amcgcr_el0_cg1nc = (amcfgr_el0_ncg > 0U) ? read_amcgcr_el0_cg1nc() : 0U;
+	amcg1idr_el0_voff = (hcr_el2_amvoffen != 0U) ? read_amcg1idr_el0_voff() : 0U;
+#endif
+
+	/*
+	 * Sanity check that all counters were disabled when the context was
+	 * previously saved.
+	 */
+
+	assert(read_amcntenset0_el0_px() == 0U);
+
+	if (amcfgr_el0_ncg > 0U) {
 		assert(read_amcntenset1_el0_px() == 0U);
 	}
-#endif
 
-	/* Restore all group 0 counters */
-	for (i = 0U; i < read_amcgcr_el0_cg0nc(); i++) {
+	/*
+	 * Restore the counter values from the local context.
+	 */
+
+	for (i = 0U; i < amcgcr_el0_cg0nc; i++) {
 		amu_group0_cnt_write(i, ctx->group0_cnts[i]);
 	}
 
-	/* Restore group 0 virtual offsets if supported and enabled. */
-	if (amu_v1p1_supported() && (read_hcr_el2_amvoffen() != 0U)) {
-		/* Not using a loop because count is fixed and index 1 DNE. */
-		amu_group0_voffset_write(0U, ctx->group0_voffsets[0U]);
-		amu_group0_voffset_write(2U, ctx->group0_voffsets[1U]);
-		amu_group0_voffset_write(3U, ctx->group0_voffsets[2U]);
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	for (i = 0U; i < amcgcr_el0_cg1nc; i++) {
+		amu_group1_cnt_write(i, ctx->group1_cnts[i]);
 	}
+#endif
 
-	/* Restore group 0 counter configuration */
-	write_amcntenset0_el0_px((UINT64_C(1) << read_amcgcr_el0_cg0nc()) - 1U);
+	/*
+	 * Restore virtual offsets for counters that offer them.
+	 */
 
-#if ENABLE_AMU_AUXILIARY_COUNTERS
-	if (amu_group1_supported()) {
-		/* Restore group 1 counters */
-		for (i = 0U; i < read_amcgcr_el0_cg1nc(); i++) {
-			if ((AMU_GROUP1_COUNTERS_MASK & (1UL << i)) != 0U) {
-				amu_group1_cnt_write(i, ctx->group1_cnts[i]);
+	if (hcr_el2_amvoffen != 0U) {
+		for (i = 0U, j = 0U; i < amcgcr_el0_cg0nc; i++) {
+			if (!amu_group0_voffset_supported(i)) {
+				continue; /* No virtual offset */
 			}
-		}
 
-		/* Restore group 1 virtual offsets if supported and enabled. */
-		if (amu_v1p1_supported() && (read_hcr_el2_amvoffen() != 0U)) {
-			uint64_t amcg1idr = read_amcg1idr_el0_voff() &
-				AMU_GROUP1_COUNTERS_MASK;
+			amu_group0_voffset_write(i, ctx->group0_voffsets[j++]);
+		}
 
-			for (i = 0U; i < read_amcgcr_el0_cg1nc(); i++) {
-				if (((amcg1idr >> i) & 1ULL) != 0ULL) {
-					amu_group1_voffset_write(i,
-						ctx->group1_voffsets[i]);
-				}
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+		for (i = 0U, j = 0U; i < amcgcr_el0_cg1nc; i++) {
+			if (((amcg1idr_el0_voff >> i) & 1U) == 0U) {
+				continue; /* No virtual offset */
 			}
+
+			amu_group1_voffset_write(i, ctx->group1_voffsets[j++]);
 		}
+#endif
+	}
+
+	/*
+	 * Re-enable counters that were disabled during context save.
+	 */
+
+	write_amcntenset0_el0_px(ctx->group0_enable);
 
-		/* Restore group 1 counter configuration */
-		write_amcntenset1_el0_px(AMU_GROUP1_COUNTERS_MASK);
+#if ENABLE_AMU_AUXILIARY_COUNTERS
+	if (amcfgr_el0_ncg > 0U) {
+		write_amcntenset1_el0_px(ctx->group1_enable);
 	}
 #endif
 
diff --git a/include/lib/extensions/amu_private.h b/lib/extensions/amu/amu_private.h
similarity index 65%
rename from include/lib/extensions/amu_private.h
rename to lib/extensions/amu/amu_private.h
index efa6f6c..eb7ff0e 100644
--- a/include/lib/extensions/amu_private.h
+++ b/lib/extensions/amu/amu_private.h
@@ -18,42 +18,21 @@
 #define AMU_GROUP0_MAX_COUNTERS		U(16)
 #define AMU_GROUP1_MAX_COUNTERS		U(16)
 
-#if ENABLE_AMU_AUXILIARY_COUNTERS
-#define AMU_GROUP1_COUNTERS_MASK	U(0)
-#endif
-
-struct amu_ctx {
-	uint64_t group0_cnts[AMU_GROUP0_MAX_COUNTERS];
-#if __aarch64__
-	/* Architected event counter 1 does not have an offset register. */
-	uint64_t group0_voffsets[AMU_GROUP0_MAX_COUNTERS-1];
-#endif
-
-#if ENABLE_AMU_AUXILIARY_COUNTERS
-	uint64_t group1_cnts[AMU_GROUP1_MAX_COUNTERS];
-#if __aarch64__
-	uint64_t group1_voffsets[AMU_GROUP1_MAX_COUNTERS];
-#endif
-#endif
-};
+#define AMU_AMCGCR_CG0NC_MAX		U(16)
 
 uint64_t amu_group0_cnt_read_internal(unsigned int idx);
 void amu_group0_cnt_write_internal(unsigned int idx, uint64_t val);
 
-#if ENABLE_AMU_AUXILIARY_COUNTERS
 uint64_t amu_group1_cnt_read_internal(unsigned int idx);
 void amu_group1_cnt_write_internal(unsigned int idx, uint64_t val);
 void amu_group1_set_evtype_internal(unsigned int idx, unsigned int val);
-#endif
 
 #if __aarch64__
 uint64_t amu_group0_voffset_read_internal(unsigned int idx);
 void amu_group0_voffset_write_internal(unsigned int idx, uint64_t val);
 
-#if ENABLE_AMU_AUXILIARY_COUNTERS
 uint64_t amu_group1_voffset_read_internal(unsigned int idx);
 void amu_group1_voffset_write_internal(unsigned int idx, uint64_t val);
 #endif
-#endif
 
 #endif /* AMU_PRIVATE_H */