Skip to content

AnalyticsEngine

The analytics engine exposes typed, frozen, Decimal-safe analytics functions on top of CanonicalDataset. Six submodules cover country, partner, commodity, time-series, balance, and comparison analytics.

API reference

The full reference is generated from the SDK's docstrings via mkdocstrings.

AnalyticsEngine

High-level orchestrator for the analytics layer.

A single AnalyticsEngine instance holds:

  • A filter chain — filters are applied in order; the resulting dataset is what each metric and aggregation sees.
  • A list of metrics — each is computed once on the filtered dataset.
  • A list of aggregations — each is computed once on the filtered dataset.

Construction is purely declarative; no work happens until run(dataset) is called.

Usage::

engine = (
    AnalyticsEngine(name="india_2022_summary")
    .add_filter(Filter.reporter(699))
    .add_filter(Filter.year(2022))
    .add_metric(Metric.count())
    .add_metric(Metric.sum_primary_value())
    .add_aggregation(
        Aggregation(
            name="by_partner",
            group_by=("partner_code",),
            metric=Metric.sum_primary_value(),
        )
    )
)
result = engine.run(dataset)
Source code in un_comtrade/analytics/__init__.py
class AnalyticsEngine:
    """High-level orchestrator for the analytics layer.

    A single `AnalyticsEngine` instance holds:

    - A **filter chain** — filters are applied in
      order; the resulting dataset is what each
      metric and aggregation sees.
    - A list of **metrics** — each is computed once
      on the filtered dataset.
    - A list of **aggregations** — each is computed
      once on the filtered dataset.

    Construction is purely declarative; no work
    happens until `run(dataset)` is called.

    Usage::

        engine = (
            AnalyticsEngine(name="india_2022_summary")
            .add_filter(Filter.reporter(699))
            .add_filter(Filter.year(2022))
            .add_metric(Metric.count())
            .add_metric(Metric.sum_primary_value())
            .add_aggregation(
                Aggregation(
                    name="by_partner",
                    group_by=("partner_code",),
                    metric=Metric.sum_primary_value(),
                )
            )
        )
        result = engine.run(dataset)
    """

    def __init__(
        self,
        *,
        name: str,
        config: Mapping[str, Any] | None = None,
    ) -> None:
        """Create an empty engine.

        Raises `AnalyticsError` when `name` is empty.
        `config` is copied into a private dict, so
        later changes to the caller's mapping do not
        affect the engine.
        """
        if not name:
            raise AnalyticsError("AnalyticsEngine.name must be non-empty")
        self._name = name
        self._config: dict[str, Any] = dict(config or {})
        self._filters: list[Filter] = []
        self._metrics: list[Metric] = []
        self._aggregations: list[Aggregation] = []

    # ----- Builder methods ------------------------------------------

    def add_filter(self, filter_: Filter) -> "AnalyticsEngine":
        """Append a `Filter` to the engine's filter
        chain. Returns `self` for chaining."""
        if not isinstance(filter_, Filter):
            raise AnalyticsError(
                f"add_filter expects a Filter; got "
                f"{type(filter_).__name__}"
            )
        self._filters.append(filter_)
        return self

    def add_metric(self, metric: Metric) -> "AnalyticsEngine":
        """Append a `Metric` to be computed against
        the filtered dataset. Returns `self` for
        chaining."""
        if not isinstance(metric, Metric):
            raise AnalyticsError(
                f"add_metric expects a Metric; got "
                f"{type(metric).__name__}"
            )
        self._metrics.append(metric)
        return self

    def add_aggregation(
        self, aggregation: Aggregation
    ) -> "AnalyticsEngine":
        """Append an `Aggregation` to be computed
        against the filtered dataset. Returns `self`
        for chaining."""
        if not isinstance(aggregation, Aggregation):
            raise AnalyticsError(
                f"add_aggregation expects an Aggregation; "
                f"got {type(aggregation).__name__}"
            )
        self._aggregations.append(aggregation)
        return self

    # ----- Read-only properties -------------------------------------

    @property
    def name(self) -> str:
        """The analysis name passed to the constructor."""
        return self._name

    @property
    def config(self) -> Mapping[str, Any]:
        """A shallow copy of the engine's config."""
        return dict(self._config)

    @property
    def filters(self) -> tuple[Filter, ...]:
        """Registered filters, in application order."""
        return tuple(self._filters)

    @property
    def metrics(self) -> tuple[Metric, ...]:
        """Registered metrics, in registration order."""
        return tuple(self._metrics)

    @property
    def aggregations(self) -> tuple[Aggregation, ...]:
        """Registered aggregations, in registration
        order."""
        return tuple(self._aggregations)

    # ----- Execution -------------------------------------------------

    def run(self, dataset: CanonicalDataset) -> AnalysisResult:
        """Apply the filter chain, compute metrics,
        and run aggregations.

        Returns a frozen `AnalysisResult`. Raises
        `AnalyticsError` if `dataset` is not a
        `CanonicalDataset`; exceptions raised by a
        filter propagate unchanged.

        Metric and aggregation failures are collected
        on the result's context instead of being
        re-raised, so that one broken metric or
        aggregation doesn't abort the whole analysis:

        - a failing metric is reported in
          `context.warnings`;
        - a failing aggregation is reported in
          `context.errors`.

        Results and durations are keyed by name. When
        two metrics (or two aggregations) share a
        name, at most one result survives, and a
        warning is recorded on the context.

        Per-step and total durations are measured
        with `time.perf_counter()`, the
        highest-resolution clock for short intervals.
        """
        if not isinstance(dataset, CanonicalDataset):
            raise AnalyticsError(
                f"AnalyticsEngine.run source must be a "
                f"CanonicalDataset; got "
                f"{type(dataset).__name__}"
            )

        started = datetime.now(timezone.utc)
        start_perf = time.perf_counter()
        warnings: list[str] = []
        errors: list[str] = []
        metric_durations: dict[str, float] = {}
        aggregation_durations: dict[str, float] = {}

        # Apply filter chain. Filter exceptions are
        # not caught; they propagate to the caller.
        current = dataset
        for f in self._filters:
            current = f.apply(current)

        # Compute metrics; failures become warnings.
        metric_values: dict[str, NumericValue] = {}
        for m in self._metrics:
            # Every processed name lands in
            # metric_durations (see `finally`), so a
            # hit here means a duplicate name.
            if m.name in metric_durations:
                warnings.append(
                    f"metric {m.name!r} is registered more "
                    f"than once; at most one result is kept"
                )
            t0 = time.perf_counter()
            try:
                metric_values[m.name] = m.compute(current)
            except MetricError as exc:
                warnings.append(
                    f"metric {m.name!r} failed: {exc}"
                )
            except Exception as exc:  # pragma: no cover
                warnings.append(
                    f"metric {m.name!r} raised {type(exc).__name__}: "
                    f"{exc}"
                )
            finally:
                metric_durations[m.name] = time.perf_counter() - t0

        # Run aggregations; failures become errors.
        aggregation_results: dict[
            str, tuple[AggregationRow, ...]
        ] = {}
        for a in self._aggregations:
            if a.name in aggregation_durations:
                warnings.append(
                    f"aggregation {a.name!r} is registered more "
                    f"than once; at most one result is kept"
                )
            t0 = time.perf_counter()
            try:
                aggregation_results[a.name] = a.apply(current)
            except AggregationError as exc:
                errors.append(
                    f"aggregation {a.name!r} failed: {exc}"
                )
            except Exception as exc:  # pragma: no cover
                errors.append(
                    f"aggregation {a.name!r} raised "
                    f"{type(exc).__name__}: {exc}"
                )
            finally:
                aggregation_durations[a.name] = (
                    time.perf_counter() - t0
                )

        finished = datetime.now(timezone.utc)
        duration = time.perf_counter() - start_perf

        context = AnalysisContext(
            analysis_name=self._name,
            config=dict(self._config),
            warnings=tuple(warnings),
            errors=tuple(errors),
            started_at=started,
            finished_at=finished,
            metric_durations=dict(metric_durations),
            aggregation_durations=dict(aggregation_durations),
        )

        return AnalysisResult(
            analysis_name=self._name,
            metric_values=dict(metric_values),
            aggregation_results=dict(aggregation_results),
            record_count=len(dataset.records),
            filtered_count=len(current.records),
            context=context,
            duration_seconds=duration,
        )

    # ----- Repr -----------------------------------------------------

    def __repr__(self) -> str:
        return (
            f"AnalyticsEngine(name={self._name!r}, "
            f"filters={len(self._filters)}, "
            f"metrics={len(self._metrics)}, "
            f"aggregations={len(self._aggregations)})"
        )

add_filter

add_filter(filter_: Filter) -> 'AnalyticsEngine'

Append a Filter to the engine's filter chain. Returns self for chaining.

Source code in un_comtrade/analytics/__init__.py
def add_filter(self, filter_: Filter) -> "AnalyticsEngine":
    """Register `filter_` as the next step of the
    filter chain and hand back the engine, so calls
    can be chained.

    Raises `AnalyticsError` when `filter_` is not a
    `Filter` instance.
    """
    if isinstance(filter_, Filter):
        self._filters.append(filter_)
        return self
    raise AnalyticsError(
        f"add_filter expects a Filter; got "
        f"{type(filter_).__name__}"
    )

add_metric

add_metric(metric: Metric) -> 'AnalyticsEngine'

Append a Metric to be computed against the filtered dataset. Returns self for chaining.

Source code in un_comtrade/analytics/__init__.py
def add_metric(self, metric: Metric) -> "AnalyticsEngine":
    """Queue `metric` for evaluation on the filtered
    dataset and hand back the engine, so calls can
    be chained.

    Raises `AnalyticsError` when `metric` is not a
    `Metric` instance.
    """
    if isinstance(metric, Metric):
        self._metrics.append(metric)
        return self
    raise AnalyticsError(
        f"add_metric expects a Metric; got "
        f"{type(metric).__name__}"
    )

add_aggregation

add_aggregation(
    aggregation: Aggregation,
) -> "AnalyticsEngine"

Append an Aggregation to be computed against the filtered dataset. Returns self for chaining.

Source code in un_comtrade/analytics/__init__.py
def add_aggregation(
    self, aggregation: Aggregation
) -> "AnalyticsEngine":
    """Queue `aggregation` for evaluation on the
    filtered dataset and hand back the engine, so
    calls can be chained.

    Raises `AnalyticsError` when `aggregation` is
    not an `Aggregation` instance.
    """
    if isinstance(aggregation, Aggregation):
        self._aggregations.append(aggregation)
        return self
    raise AnalyticsError(
        f"add_aggregation expects an Aggregation; "
        f"got {type(aggregation).__name__}"
    )

run

run(dataset: CanonicalDataset) -> AnalysisResult

Apply the filter chain, compute metrics, and run aggregations.

Returns a frozen AnalysisResult. Raises AnalyticsError if dataset is not a CanonicalDataset. Failures are collected on the result's context rather than re-raised: metric failures become warnings and aggregation failures become errors. This way one broken metric or aggregation doesn't abort the whole analysis.

Source code in un_comtrade/analytics/__init__.py
def run(self, dataset: CanonicalDataset) -> AnalysisResult:
    """Apply the filter chain, compute metrics,
    and run aggregations.

    Returns a frozen `AnalysisResult`. Raises
    `AnalyticsError` if `dataset` is not a
    `CanonicalDataset`; exceptions raised by a
    filter propagate unchanged.

    Metric and aggregation failures are collected
    on the result's context instead of being
    re-raised, so that one broken metric or
    aggregation doesn't abort the whole analysis:

    - a failing metric is reported in
      `context.warnings`;
    - a failing aggregation is reported in
      `context.errors`.

    Results and durations are keyed by name. When
    two metrics (or two aggregations) share a name,
    at most one result survives, and a warning is
    recorded on the context.

    Per-step and total durations are measured with
    `time.perf_counter()`, the highest-resolution
    clock for short intervals.
    """
    if not isinstance(dataset, CanonicalDataset):
        raise AnalyticsError(
            f"AnalyticsEngine.run source must be a "
            f"CanonicalDataset; got "
            f"{type(dataset).__name__}"
        )

    started = datetime.now(timezone.utc)
    start_perf = time.perf_counter()
    warnings: list[str] = []
    errors: list[str] = []
    metric_durations: dict[str, float] = {}
    aggregation_durations: dict[str, float] = {}

    # Apply filter chain. Filter exceptions are not
    # caught; they propagate to the caller.
    current = dataset
    for f in self._filters:
        current = f.apply(current)

    # Compute metrics; failures become warnings.
    metric_values: dict[str, NumericValue] = {}
    for m in self._metrics:
        # Every processed name lands in
        # metric_durations (see `finally`), so a hit
        # here means a duplicate name.
        if m.name in metric_durations:
            warnings.append(
                f"metric {m.name!r} is registered more "
                f"than once; at most one result is kept"
            )
        t0 = time.perf_counter()
        try:
            metric_values[m.name] = m.compute(current)
        except MetricError as exc:
            warnings.append(
                f"metric {m.name!r} failed: {exc}"
            )
        except Exception as exc:  # pragma: no cover
            warnings.append(
                f"metric {m.name!r} raised {type(exc).__name__}: "
                f"{exc}"
            )
        finally:
            metric_durations[m.name] = time.perf_counter() - t0

    # Run aggregations; failures become errors.
    aggregation_results: dict[
        str, tuple[AggregationRow, ...]
    ] = {}
    for a in self._aggregations:
        if a.name in aggregation_durations:
            warnings.append(
                f"aggregation {a.name!r} is registered more "
                f"than once; at most one result is kept"
            )
        t0 = time.perf_counter()
        try:
            aggregation_results[a.name] = a.apply(current)
        except AggregationError as exc:
            errors.append(
                f"aggregation {a.name!r} failed: {exc}"
            )
        except Exception as exc:  # pragma: no cover
            errors.append(
                f"aggregation {a.name!r} raised "
                f"{type(exc).__name__}: {exc}"
            )
        finally:
            aggregation_durations[a.name] = (
                time.perf_counter() - t0
            )

    finished = datetime.now(timezone.utc)
    duration = time.perf_counter() - start_perf

    context = AnalysisContext(
        analysis_name=self._name,
        config=dict(self._config),
        warnings=tuple(warnings),
        errors=tuple(errors),
        started_at=started,
        finished_at=finished,
        metric_durations=dict(metric_durations),
        aggregation_durations=dict(aggregation_durations),
    )

    return AnalysisResult(
        analysis_name=self._name,
        metric_values=dict(metric_values),
        aggregation_results=dict(aggregation_results),
        record_count=len(dataset.records),
        filtered_count=len(current.records),
        context=context,
        duration_seconds=duration,
    )

country

Country-level analytics (P6-002).

This module is the first concrete analytics submodule built on top of the AnalyticsEngine foundation (P6-001). It provides five country-level analytics that operate exclusively on CanonicalDataset:

  • total_imports(...) — sum of imports, optionally restricted to one reporter and to a single year or a tuple of years.
  • total_exports(...) — sum of exports, optionally restricted to one reporter and to a single year or a tuple of years.
  • country_ranking(...) — rank reporters by total trade / exports / imports, with optional flow filter and limit.
  • country_summary(...) — one-stop summary per reporter: totals, balance, partner count, year range.
  • country_trend(...) — exports / imports / balance per year (or per period) for a given reporter.

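All functions accept a CanonicalDataset. The return types are:

  • a Decimal, for total_*;
  • a frozen dataclass, for country_summary (a CountrySummary, or None when the reporter has no records) and country_trend (a CountryTrend);
  • a tuple of frozen CountryRankingRow dataclasses, for country_ranking.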

QE-007 refactor: this module's filter / group / aggregate / sort logic is now built on top of the internal Query engine (see un_comtrade.analytics._query_engine). The public API is unchanged; only the internal implementation now delegates to Query(...), Query.filter(...), Query.group_by(...), sum(...), summarize(...), and Query.sort(...).

The dataclasses are frozen (ADR-0013) and use Decimal for monetary values (ADR-0027).

The module is decoupled from the transport layer (same constraint as AnalyticsEngine): only stdlib + intra-package imports.

CountryAnalyticsError

Bases: Exception

Raised when a country-level analytics operation cannot be performed (e.g. unknown ranking field, missing reporter).

Source code in un_comtrade/analytics/country.py
class CountryAnalyticsError(Exception):
    """Signals that a country-level analytics call
    could not be carried out.

    Raised for invalid arguments (for example an
    unknown ranking field or granularity, or
    conflicting year filters) and for result
    dataclasses built with non-`Decimal` monetary
    values.
    """

CountryRankingRow dataclass

One row of a country ranking.

Captures totals for a single reporter plus the ISO3 / name metadata if present in the source records.

Source code in un_comtrade/analytics/country.py
@dataclass(frozen=True)
class CountryRankingRow:
    """One row of a country ranking.

    Captures totals for a single reporter plus the
    ISO3 / name metadata if present in the source
    records.
    """

    reporter_code: int
    reporter_iso3: str | None
    reporter_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    total_trade_value: Decimal
    trade_balance: Decimal
    record_count: int

    def __post_init__(self) -> None:
        if not isinstance(
            self.total_exports, Decimal
        ) or not isinstance(
            self.total_imports, Decimal
        ):
            raise CountryAnalyticsError(
                "total_exports / total_imports must be "
                "Decimal"
            )
        if not isinstance(self.total_trade_value, Decimal):
            raise CountryAnalyticsError(
                "total_trade_value must be Decimal"
            )
        if not isinstance(self.trade_balance, Decimal):
            raise CountryAnalyticsError(
                "trade_balance must be Decimal"
            )

CountrySummary dataclass

One-stop summary of a single reporter's activity in a CanonicalDataset.

Captures totals, trade balance, partner count, record count, and the observed year range.

Source code in un_comtrade/analytics/country.py
@dataclass(frozen=True)
class CountrySummary:
    """One-stop summary of a single reporter's
    activity in a `CanonicalDataset`.

    Captures totals, trade balance, partner count,
    record count, and the observed year range.
    """

    reporter_code: int
    reporter_iso3: str | None
    reporter_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    total_trade: Decimal
    trade_balance: Decimal
    partner_count: int
    record_count: int
    year_range: tuple[int, int] | None

    def __post_init__(self) -> None:
        for f in (
            "total_exports", "total_imports",
            "total_trade", "trade_balance",
        ):
            v = getattr(self, f)
            if not isinstance(v, Decimal):
                raise CountryAnalyticsError(
                    f"{f} must be Decimal; got {type(v).__name__}"
                )

CountryTrendPoint dataclass

One point on a country trend (one year or one period).

Source code in un_comtrade/analytics/country.py
@dataclass(frozen=True)
class CountryTrendPoint:
    """A single observation on a country trend,
    covering one year or one period.

    `exports`, `imports`, `total_trade` and
    `trade_balance` must all be `Decimal`; anything
    else raises `CountryAnalyticsError`.
    """

    year: int
    period: str
    exports: Decimal
    imports: Decimal
    total_trade: Decimal
    trade_balance: Decimal
    record_count: int

    def __post_init__(self) -> None:
        # Locate the first monetary field that is not
        # a Decimal, if any.
        offender = next(
            (
                attr
                for attr in ("exports", "imports", "total_trade", "trade_balance")
                if not isinstance(getattr(self, attr), Decimal)
            ),
            None,
        )
        if offender is not None:
            bad_value = getattr(self, offender)
            raise CountryAnalyticsError(
                f"{offender} must be Decimal; got {type(bad_value).__name__}"
            )

CountryTrend dataclass

Time-series of country activity for one reporter.

points is sorted by (year, period) in ascending order.

Source code in un_comtrade/analytics/country.py
@dataclass(frozen=True)
class CountryTrend:
    """Time series of one reporter's trade activity.

    `points` is documented as ascending by
    (year, period). This class does not re-sort it,
    so producers must supply points in that order.
    Derived totals are computed on access.
    """

    reporter_code: int
    points: tuple[CountryTrendPoint, ...] = field(default_factory=tuple)

    @property
    def years(self) -> tuple[int, ...]:
        """Distinct years covered by `points`,
        ascending."""
        distinct_years = set()
        for point in self.points:
            distinct_years.add(point.year)
        return tuple(sorted(distinct_years))

    @property
    def total_exports(self) -> Decimal:
        """Sum of `exports` across all points."""
        export_values = (point.exports for point in self.points)
        return _sum_primary_value_iter(export_values)

    @property
    def total_imports(self) -> Decimal:
        """Sum of `imports` across all points."""
        import_values = (point.imports for point in self.points)
        return _sum_primary_value_iter(import_values)

    @property
    def total_trade(self) -> Decimal:
        """Exports plus imports across all points."""
        exports = self.total_exports
        imports = self.total_imports
        return exports + imports

total_imports

total_imports(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    year: int | None = None,
    years: tuple[int, ...] | None = None,
) -> Decimal

Sum of imports (flow_code == "M") for the optional filters.

Parameters

  • dataset — The CanonicalDataset to analyse.
  • reporter_code — If supplied, only records whose reporter.reporter_code == reporter_code contribute.
  • year — If supplied, only records with ref_year == year contribute.
  • years — If supplied, only records whose ref_year is in this tuple contribute. Mutually exclusive with year.

Returns

Decimal — total import value (USD). Returns Decimal("0") when no records match.

Source code in un_comtrade/analytics/country.py
def total_imports(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    year: int | None = None,
    years: tuple[int, ...] | None = None,
) -> Decimal:
    """Total import value (records with
    `flow_code == "M"`), optionally narrowed.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        When given, only records whose
        `reporter.reporter_code` equals it count.
    year
        When given, only records whose `ref_year`
        equals it count.
    years
        When given, only records whose `ref_year`
        appears in this tuple count. Cannot be
        combined with `year`.

    Returns
    -------
    Decimal
        Import value in USD; `Decimal("0")` when
        nothing matches.

    Raises
    ------
    CountryAnalyticsError
        If both `year` and `years` are supplied.
    """
    both_year_filters = year is not None and years is not None
    if both_year_filters:
        raise CountryAnalyticsError("year and years are mutually exclusive")
    _check_dataset(dataset, fn_name="total_imports")
    import_records = _filter_records(
        dataset,
        reporter_code=reporter_code,
        flow_code="M",
        year=year,
        years=years,
    )
    return _sum_primary_value(import_records)

total_exports

total_exports(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    year: int | None = None,
    years: tuple[int, ...] | None = None,
) -> Decimal

Sum of exports (flow_code == "X") for the optional filters. Mirror of total_imports.

Source code in un_comtrade/analytics/country.py
def total_exports(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    year: int | None = None,
    years: tuple[int, ...] | None = None,
) -> Decimal:
    """Total export value (records with
    `flow_code == "X"`), optionally narrowed.

    Takes the same filters as `total_imports`:
    `reporter_code`, plus either a single `year` or
    a tuple of `years` (not both). Returns
    `Decimal("0")` when nothing matches.

    Raises
    ------
    CountryAnalyticsError
        If both `year` and `years` are supplied.
    """
    both_year_filters = year is not None and years is not None
    if both_year_filters:
        raise CountryAnalyticsError("year and years are mutually exclusive")
    _check_dataset(dataset, fn_name="total_exports")
    export_records = _filter_records(
        dataset,
        reporter_code=reporter_code,
        flow_code="X",
        year=year,
        years=years,
    )
    return _sum_primary_value(export_records)

country_ranking

country_ranking(
    dataset: CanonicalDataset,
    *,
    flow: str | None = None,
    by: str = "total_trade_value",
    descending: bool = True,
    limit: int | None = None,
) -> tuple[CountryRankingRow, ...]

Rank reporters by total trade (or by a specific flow / metric).

Parameters

  • dataset — The CanonicalDataset to analyse.
  • flow — Optional flow filter:
      – "X" keeps exports only.
      – "M" keeps imports only.
      – None (default) keeps both flows; totals are exports + imports.
  • by — One of "total_trade_value" (default), "exports", "imports", "trade_balance", or "record_count".
  • descending — When True (default), largest values come first; when False, smallest values come first.
  • limit — If supplied, return only the top limit rows (after sorting).

Returns

tuple[CountryRankingRow, ...] — rows sorted by the field named in `by`, in the requested direction. Empty tuple if no records match.

Source code in un_comtrade/analytics/country.py
def _country_ranking_flow_totals(
    dataset: CanonicalDataset,
    flow_code: str,
) -> dict[int, Decimal]:
    """Sum `primary_value` per `reporter_code` for
    one flow code, using the internal Query engine.

    A group whose `summarize` result has no sum is
    mapped to `Decimal("0")`.
    """
    totals: dict[int, Decimal] = {}
    grouped = (
        Query(dataset)
        .filter(flow_code=flow_code)
        .group_by("reporter_code")
        .execute()
    )
    for group in grouped.groups:
        summary = _q_summarize(
            group.records, field="primary_value"
        )
        totals[group.key[0]] = (
            summary.sum if summary.sum is not None
            else Decimal("0")
        )
    return totals


def country_ranking(
    dataset: CanonicalDataset,
    *,
    flow: str | None = None,
    by: str = "total_trade_value",
    descending: bool = True,
    limit: int | None = None,
) -> tuple[CountryRankingRow, ...]:
    """Rank reporters by total trade (or by a
    specific flow / metric).

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    flow
        Optional flow filter. `"X"` keeps exports
        only; `"M"` keeps imports only; `None`
        (default) keeps every record. In that case
        totals are exports + imports, and
        `record_count` counts all of the reporter's
        records.
    by
        One of `"total_trade_value"` (default),
        `"exports"`, `"imports"`, `"trade_balance"`,
        or `"record_count"`.
    descending
        When `True` (default), largest values
        first; when `False`, smallest values
        first.
    limit
        If supplied, return only the top `limit`
        rows (after sorting).

    Returns
    -------
    tuple[CountryRankingRow, ...]
        Rows sorted by `by` in the requested
        direction. Ties keep the order in which the
        Query engine returned the reporter groups.
        Empty tuple if no records match.

    Raises
    ------
    CountryAnalyticsError
        If `by` is not a known ranking field,
        `flow` is not `"X"`, `"M"` or `None`, or
        `limit` is negative.
    """
    if by not in _COUNTRY_RANKING_FIELDS:
        raise CountryAnalyticsError(
            f"Unknown ranking field {by!r}; "
            f"valid: {sorted(_COUNTRY_RANKING_FIELDS)}"
        )
    # Any other flow code would either match nothing
    # or yield rows whose totals are all zero (the
    # per-flow totals below are only computed for
    # flow=None), so reject it up front.
    if flow is not None and flow not in ("X", "M"):
        raise CountryAnalyticsError(
            f"Unknown flow {flow!r}; valid: 'X', 'M', None"
        )
    if limit is not None and limit < 0:
        raise CountryAnalyticsError(
            "limit must be non-negative"
        )
    # CountryRankingRow attribute backing each public
    # ranking field.
    sort_attrs = {
        "total_trade_value": "total_trade_value",
        "exports": "total_exports",
        "imports": "total_imports",
        "trade_balance": "trade_balance",
        "record_count": "record_count",
    }
    sort_attr = sort_attrs.get(by)
    if sort_attr is None:
        # `_COUNTRY_RANKING_FIELDS` accepted a field
        # this function has no sort key for.
        raise CountryAnalyticsError(f"unreachable: {by}")
    _check_dataset(dataset, fn_name="country_ranking")

    # Group the (optionally flow-filtered) records by
    # reporter via the Query engine (QE-007). These
    # groups decide which reporters get a row and
    # supply each row's record_count.
    q = Query(dataset)
    if flow is not None:
        q = q.filter(flow_code=flow)
    result = q.group_by("reporter_code").execute()
    if not result.groups:
        return ()

    # Without a flow filter each group mixes flows,
    # so exports and imports come from two per-flow
    # queries.
    exports_by_reporter: dict[int, Decimal] = {}
    imports_by_reporter: dict[int, Decimal] = {}
    if flow is None:
        exports_by_reporter = _country_ranking_flow_totals(
            dataset, "X"
        )
        imports_by_reporter = _country_ranking_flow_totals(
            dataset, "M"
        )

    # First-seen ISO3 / name per reporter, taken from
    # the unfiltered dataset.
    meta: dict[int, dict[str, Any]] = {}
    for record in dataset.records:
        code = record.reporter.reporter_code
        if code not in meta:
            meta[code] = {
                "iso3": record.reporter.iso3,
                "name": record.reporter.name,
            }

    zero = Decimal("0")
    rows_by_code: dict[int, CountryRankingRow] = {}
    for group in result.groups:
        code = group.key[0]
        # `summarize` gives count and sum in one pass.
        agg = _q_summarize(
            group.records, field="primary_value"
        )
        flow_total = agg.sum if agg.sum is not None else zero
        if flow == "X":
            x_value, m_value = flow_total, zero
        elif flow == "M":
            x_value, m_value = zero, flow_total
        else:
            x_value = exports_by_reporter.get(code, zero)
            m_value = imports_by_reporter.get(code, zero)
        reporter_meta = meta.get(code, {})
        rows_by_code[code] = CountryRankingRow(
            reporter_code=code,
            reporter_iso3=reporter_meta.get("iso3"),
            reporter_name=reporter_meta.get("name"),
            total_exports=x_value,
            total_imports=m_value,
            total_trade_value=x_value + m_value,
            trade_balance=x_value - m_value,
            record_count=agg.count,
        )

    # The Query engine sorts records, not
    # CountryRankingRow objects, so rank with
    # Python's stable sorted(); `reverse` preserves
    # the relative order of ties.
    rows = sorted(
        rows_by_code.values(),
        key=lambda row: getattr(row, sort_attr),
        reverse=descending,
    )
    if limit is not None:
        rows = rows[:limit]
    return tuple(rows)

country_summary

country_summary(
    dataset: CanonicalDataset, reporter_code: int
) -> CountrySummary | None

Build a CountrySummary for one reporter.

Returns None when the reporter has no records in the dataset.

Source code in un_comtrade/analytics/country.py
def country_summary(
    dataset: CanonicalDataset,
    reporter_code: int,
) -> CountrySummary | None:
    """Summarise one reporter's activity as a
    `CountrySummary`.

    Reporter ISO3 / name come from the first
    matching record. Export and import totals use
    records with flow codes `"X"` and `"M"`.
    `record_count` counts every matching record.
    `year_range` is the (earliest, latest)
    `ref_year` seen.

    Returns `None` if the reporter has no records in
    `dataset`.
    """
    _check_dataset(dataset, fn_name="country_summary")
    reporter_records = _filter_records(
        dataset, reporter_code=reporter_code
    )
    if not reporter_records:
        return None

    first_reporter = reporter_records[0].reporter
    iso3 = first_reporter.iso3
    name = first_reporter.name

    export_total = _sum_primary_value(
        rec for rec in reporter_records if rec.flow.flow_code == "X"
    )
    import_total = _sum_primary_value(
        rec for rec in reporter_records if rec.flow.flow_code == "M"
    )

    distinct_partners = {rec.partner.partner_code for rec in reporter_records}
    observed_years = [rec.ref_year for rec in reporter_records]
    if observed_years:
        span: tuple[int, int] | None = (
            min(observed_years),
            max(observed_years),
        )
    else:
        span = None

    return CountrySummary(
        reporter_code=reporter_code,
        reporter_iso3=iso3,
        reporter_name=name,
        total_exports=export_total,
        total_imports=import_total,
        total_trade=export_total + import_total,
        trade_balance=export_total - import_total,
        partner_count=len(distinct_partners),
        record_count=len(reporter_records),
        year_range=span,
    )

country_trend

country_trend(
    dataset: CanonicalDataset,
    reporter_code: int,
    *,
    granularity: str = "year",
) -> CountryTrend

Build a CountryTrend for one reporter.

Parameters

  • dataset — The CanonicalDataset to analyse.
  • reporter_code — The reporter to summarise.
  • granularity — Controls how points are grouped:
      – "year" (default) groups by ref_year and produces one point per calendar year.
      – "period" groups by period (e.g. "2022", "202201") and can produce intra-year points.

Returns

CountryTrend — trend with points sorted by (year, period). Returns an empty CountryTrend when the reporter has no records.

Source code in un_comtrade/analytics/country.py
def country_trend(
    dataset: CanonicalDataset,
    reporter_code: int,
    *,
    granularity: str = "year",
) -> CountryTrend:
    """Build a `CountryTrend` (exports / imports over
    time) for a single reporter.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        The reporter to summarise.
    granularity
        `"year"` (default) buckets records by
        `ref_year`, giving one point per calendar
        year; `"period"` buckets by the `period`
        string (e.g. `"2022"`, `"202201"`), which can
        give several points inside one year.

    Returns
    -------
    CountryTrend
        Points ordered by `(year, period)`. A trend
        with no points is returned when the reporter
        has no records.

    Raises
    ------
    CountryAnalyticsError
        If `granularity` is neither `"year"` nor
        `"period"`.
    """
    if granularity not in ("year", "period"):
        raise CountryAnalyticsError(
            f"Unknown granularity {granularity!r}; "
            f"valid: 'year', 'period'"
        )
    _check_dataset(dataset, fn_name="country_trend")

    # Early exit: nothing recorded for this reporter.
    if not _filter_records(dataset, reporter_code=reporter_code):
        return CountryTrend(reporter_code=reporter_code, points=())

    yearly = granularity == "year"
    # Single-field grouping: every group key is a
    # 1-tuple holding the year or the period.
    grouped = (
        Query(dataset)
        .filter(reporter_code=reporter_code)
        .group_by("ref_year" if yearly else "period")
        .execute()
    )

    trend_points = []
    for bucket in grouped.groups:
        members = bucket.records
        exports = _sum_primary_value(
            [rec for rec in members if rec.flow.flow_code == "X"]
        )
        imports = _sum_primary_value(
            [rec for rec in members if rec.flow.flow_code == "M"]
        )
        # The dimension not used as the group key is
        # taken from the bucket's first record; for
        # yearly buckets this means `period` is that
        # record's period.
        sample = members[0]
        if yearly:
            year, period = bucket.key[0], sample.period
        else:
            year, period = sample.ref_year, bucket.key[0]
        trend_points.append(
            CountryTrendPoint(
                year=year,
                period=period,
                exports=exports,
                imports=imports,
                total_trade=exports + imports,
                trade_balance=exports - imports,
                record_count=len(members),
            )
        )

    ordered = sorted(trend_points, key=lambda pt: (pt.year, pt.period))
    return CountryTrend(
        reporter_code=reporter_code,
        points=tuple(ordered),
    )

partner

Partner-level analytics (P6-003).

This module is the second concrete analytics submodule built on top of the AnalyticsEngine foundation (P6-001). It provides four partner-level analytics that operate exclusively on CanonicalDataset:

  • top_partners(...) — rank partners by trade value for a given reporter.
  • partner_growth(...) — year-over-year (or period-over-period) growth of a specific partner's trade with the reporter.
  • partner_balance(...) — exports minus imports per partner for a given reporter.
  • bilateral_summary(...) — comprehensive summary of trade between two reporters (or a reporter and a partner), including the mirror flow from the partner's perspective.

All functions accept a CanonicalDataset and return either a frozen dataclass or None (for bilateral_summary) or a tuple of frozen dataclasses (for top_partners, partner_balance). The growth function returns a PartnerGrowth container that includes both the per-period points and the absolute / relative change summary.

The module reuses the Filter, Metric, and Aggregation primitives from the parent AnalyticsEngine — no new abstractions are introduced. The dataclasses are frozen (ADR-0013) and use Decimal for monetary values (ADR-0027).

The module is decoupled from the transport layer (same constraint as AnalyticsEngine): only stdlib + intra-package imports.

PartnerAnalyticsError

Bases: AnalyticsError

Raised when a partner-level analytics operation cannot be performed (e.g. unknown ranking field, missing reporter / partner).

Source code in un_comtrade/analytics/partner.py
class PartnerAnalyticsError(AnalyticsError):
    """Error raised by the partner-level analytics.

    Signals that a partner operation could not be
    carried out. Examples are an unknown ranking
    field or granularity, a negative `limit`, or a
    non-`Decimal` monetary value on one of the
    partner result rows.
    """

PartnerRankingRow dataclass

One row of a partner ranking.

Captures totals for a single partner (relative to a fixed reporter) plus the ISO3 / name metadata if present in the source records.

Source code in un_comtrade/analytics/partner.py
@dataclass(frozen=True)
class PartnerRankingRow:
    """One row of a partner ranking.

    Captures totals for a single partner (relative
    to a fixed reporter) plus the ISO3 / name
    metadata if present in the source records.
    """

    partner_code: int
    partner_iso3: str | None
    partner_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    total_trade: Decimal
    trade_balance: Decimal
    record_count: int

    def __post_init__(self) -> None:
        for f in (
            "total_exports", "total_imports",
            "total_trade", "trade_balance",
        ):
            v = getattr(self, f)
            if not isinstance(v, Decimal):
                raise PartnerAnalyticsError(
                    f"{f} must be Decimal; got {type(v).__name__}"
                )

PartnerGrowthPoint dataclass

One point on a partner growth series.

Source code in un_comtrade/analytics/partner.py
@dataclass(frozen=True)
class PartnerGrowthPoint:
    """A single `(year, period)` observation in a
    partner growth series. Monetary fields must be
    `Decimal`."""

    year: int
    period: str
    total_trade: Decimal
    exports: Decimal
    imports: Decimal
    record_count: int

    def __post_init__(self) -> None:
        for attr in ("total_trade", "exports", "imports"):
            value = getattr(self, attr)
            if isinstance(value, Decimal):
                continue
            raise PartnerAnalyticsError(
                f"{attr} must be Decimal; got {type(value).__name__}"
            )

PartnerGrowth dataclass

Time-series of partner growth for one reporter / partner pair.

points is sorted by (year, period). The absolute_change is last_total_trade - first_total_trade. The relative_change is (last - first) / first when first != 0, else None. The cagr is the compound annual growth rate when the series was built with year granularity and has ≥ 2 points spanning at least 1 year, else None.

Source code in un_comtrade/analytics/partner.py
@dataclass(frozen=True)
class PartnerGrowth:
    """Growth series for one reporter / partner
    pair.

    `points` is ordered by `(year, period)`.
    `absolute_change` is the last point's
    `total_trade` minus the first point's.
    `relative_change` is that difference divided by
    the first value, or `None` when the first value
    is zero. `cagr` is the compound annual growth
    rate. It is `None` when it was not computed;
    computing it needs at least two points spanning
    one or more years.
    """

    reporter_code: int
    partner_code: int
    points: tuple[PartnerGrowthPoint, ...] = field(
        default_factory=tuple
    )
    absolute_change: Decimal = Decimal("0")
    relative_change: Decimal | None = None
    cagr: Decimal | None = None

    @property
    def years(self) -> tuple[int, ...]:
        """Distinct years present in `points`, in
        ascending order."""
        seen: set[int] = set()
        for point in self.points:
            seen.add(point.year)
        return tuple(sorted(seen))

PartnerBalanceRow dataclass

One row of a partner balance view.

Sibling of PartnerRankingRow — kept as a separate type so callers can opt into the balance view semantically.

Source code in un_comtrade/analytics/partner.py
@dataclass(frozen=True)
class PartnerBalanceRow:
    """One row of a partner balance view.

    Sibling of `PartnerRankingRow` — kept as a
    separate type so callers can opt into the
    balance view semantically.
    """

    partner_code: int
    partner_iso3: str | None
    partner_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    trade_balance: Decimal
    total_trade: Decimal
    record_count: int

    def __post_init__(self) -> None:
        for f in (
            "total_exports", "total_imports",
            "trade_balance", "total_trade",
        ):
            v = getattr(self, f)
            if not isinstance(v, Decimal):
                raise PartnerAnalyticsError(
                    f"{f} must be Decimal; got {type(v).__name__}"
                )

BilateralSummary dataclass

Comprehensive summary of trade between a reporter and a partner.

Captures BOTH sides of the relationship:

  • reporter_to_partner_exports / reporter_to_partner_imports — flows reported by reporter_code with partner_code as counterparty.
  • partner_to_reporter_exports / partner_to_reporter_imports — mirror flows reported by partner_code with reporter_code as counterparty (i.e. reporter == partner_code and partner == reporter_code). Useful for reconciling asymmetries between the two sides' reporting.

Returns None from bilateral_summary(...) when the pair has no records on either side.

Source code in un_comtrade/analytics/partner.py
@dataclass(frozen=True)
class BilateralSummary:
    """Comprehensive summary of trade between a
    reporter and a partner.

    Captures BOTH sides of the relationship:

    - `reporter_to_partner_exports` /
      `reporter_to_partner_imports` — flows
      reported by `reporter_code` with
      `partner_code` as counterparty.
    - `partner_to_reporter_exports` /
      `partner_to_reporter_imports` — mirror
      flows where the counterparty is the
      reporter and the partner is the partner
      (i.e. `reporter == partner_code` and
      `partner == reporter_code`). Useful for
      reconciling asymmetries between the two
      sides' reporting.

    Returns `None` from `bilateral_summary(...)`
    when the pair has no records on either side.
    """

    reporter_code: int
    partner_code: int
    partner_iso3: str | None
    partner_name: str | None
    reporter_to_partner_exports: Decimal
    reporter_to_partner_imports: Decimal
    partner_to_reporter_exports: Decimal
    partner_to_reporter_imports: Decimal
    total_exports: Decimal
    total_imports: Decimal
    total_trade: Decimal
    record_count: int
    year_range: tuple[int, int] | None

    def __post_init__(self) -> None:
        for f in (
            "reporter_to_partner_exports",
            "reporter_to_partner_imports",
            "partner_to_reporter_exports",
            "partner_to_reporter_imports",
            "total_exports", "total_imports",
            "total_trade",
        ):
            v = getattr(self, f)
            if not isinstance(v, Decimal):
                raise PartnerAnalyticsError(
                    f"{f} must be Decimal; got {type(v).__name__}"
                )

top_partners

top_partners(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    flow: str | None = None,
    by: str = "total_trade",
    descending: bool = True,
    limit: int | None = None,
) -> tuple[PartnerRankingRow, ...]

Rank partners by trade value for a fixed reporter.

Parameters

dataset The CanonicalDataset to analyse. reporter_code The reporter whose partners to rank. flow Optional flow filter. "X" keeps exports only; "M" keeps imports only; None (default) keeps both flows (totals are exports + imports). by One of "total_trade" (default), "exports", "imports", "trade_balance", "abs_trade_balance", or "record_count". descending When True (default), largest values first. limit If supplied, return only the top limit rows (after sorting).

Returns

tuple[PartnerRankingRow, ...] Sorted by by. Empty tuple when no partners match.

Source code in un_comtrade/analytics/partner.py
def top_partners(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    flow: str | None = None,
    by: str = "total_trade",
    descending: bool = True,
    limit: int | None = None,
) -> tuple[PartnerRankingRow, ...]:
    """Rank partners by trade value for a fixed
    reporter.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        The reporter whose partners to rank.
    flow
        Optional flow filter. `"X"` keeps exports
        only; `"M"` keeps imports only; `None`
        (default) keeps both flows (totals are
        exports + imports). Record counts and
        totals are computed from the same filtered
        records, so export / import totals only
        include `"X"` / `"M"` records that passed
        the filter.
    by
        One of `"total_trade"` (default),
        `"exports"`, `"imports"`, `"trade_balance"`,
        `"abs_trade_balance"`, or `"record_count"`.
    descending
        When `True` (default), largest values
        first.
    limit
        If supplied, return only the top `limit`
        rows (after sorting).

    Returns
    -------
    tuple[PartnerRankingRow, ...]
        Sorted by `by`. Empty tuple when no
        partners match.

    Raises
    ------
    PartnerAnalyticsError
        If `by` is not a known ranking field or
        `limit` is negative.
    """
    if by not in _PARTNER_RANKING_FIELDS:
        raise PartnerAnalyticsError(
            f"Unknown ranking field {by!r}; "
            f"valid: {sorted(_PARTNER_RANKING_FIELDS)}"
        )
    if limit is not None and limit < 0:
        raise PartnerAnalyticsError(
            "limit must be non-negative"
        )
    _check_canonical_dataset(dataset, fn_name="top_partners")

    records = _select_records(
        dataset,
        reporter_code=reporter_code,
        flow_code=flow,
    )
    if not records:
        return ()

    # Bucket the already-filtered records by partner
    # in a single pass. Counts, metadata and totals
    # all derive from this one record set, so the
    # `flow` filter applies uniformly. The totals must
    # not be re-queried from the unfiltered dataset:
    # that would ignore flow codes other than "X"/"M".
    by_partner: dict[int, list] = {}
    for record in records:
        by_partner.setdefault(
            record.partner.partner_code, []
        ).append(record)

    rows: list[PartnerRankingRow] = []
    for code in sorted(by_partner):
        group = by_partner[code]
        # ISO3 / name come from the partner's first
        # record, as encountered in `records`.
        first = group[0]
        x = _sum_primary_value(
            r for r in group if r.flow.flow_code == "X"
        )
        m = _sum_primary_value(
            r for r in group if r.flow.flow_code == "M"
        )
        rows.append(
            PartnerRankingRow(
                partner_code=code,
                partner_iso3=first.partner.iso3,
                partner_name=first.partner.name,
                total_exports=x,
                total_imports=m,
                total_trade=x + m,
                trade_balance=x - m,
                record_count=len(group),
            )
        )

    sort_keys = {
        "total_trade": lambda row: row.total_trade,
        "exports": lambda row: row.total_exports,
        "imports": lambda row: row.total_imports,
        "trade_balance": lambda row: row.trade_balance,
        "abs_trade_balance": lambda row: abs(row.trade_balance),
        "record_count": lambda row: row.record_count,
    }
    key_fn = sort_keys.get(by)
    if key_fn is None:
        # Guard in case `_PARTNER_RANKING_FIELDS` ever
        # lists a field with no sort key here.
        raise PartnerAnalyticsError(f"unreachable: {by}")

    # `list.sort` is stable, so ties keep ascending
    # partner_code order.
    rows.sort(key=key_fn, reverse=descending)
    if limit is not None:
        rows = rows[:limit]
    return tuple(rows)

partner_growth

partner_growth(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    partner_code: int,
    granularity: str = "year",
) -> PartnerGrowth

Compute partner growth for one reporter / partner pair.

Parameters

dataset The CanonicalDataset to analyse. reporter_code The reporter whose side of the trade is counted. partner_code The partner whose growth is computed. granularity "year" (default) groups by ref_year; "period" groups by period string.

Returns

PartnerGrowth Container with sorted per-period points plus absolute / relative change summary and, for "year" granularity, CAGR. Returns an empty PartnerGrowth when the pair has no records.

Source code in un_comtrade/analytics/partner.py
def partner_growth(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    partner_code: int,
    granularity: str = "year",
) -> PartnerGrowth:
    """Compute partner growth for one
    reporter / partner pair.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        The reporter whose side of the trade is
        counted.
    partner_code
        The partner whose growth is computed.
    granularity
        `"year"` (default) produces one point per
        `ref_year`, summing every period of that
        year. `"period"` produces one point per
        `(ref_year, period)` pair.

    Returns
    -------
    PartnerGrowth
        Container with `points` sorted by
        `(year, period)` plus the absolute /
        relative change summary and, for `"year"`
        granularity, the CAGR. Returns an empty
        `PartnerGrowth` when the pair has no
        records.

    Raises
    ------
    PartnerAnalyticsError
        If `granularity` is not `"year"` or
        `"period"`.
    """
    if granularity not in ("year", "period"):
        raise PartnerAnalyticsError(
            f"Unknown granularity {granularity!r}; "
            f"valid: 'year', 'period'"
        )
    _check_canonical_dataset(dataset, fn_name="partner_growth")

    records = _select_records(
        dataset,
        reporter_code=reporter_code,
        partner_code=partner_code,
    )
    if not records:
        return PartnerGrowth(
            reporter_code=reporter_code,
            partner_code=partner_code,
        )

    # Yearly granularity buckets by calendar year
    # alone. Sub-annual (e.g. monthly) periods are
    # therefore summed into one point per year, and the
    # CAGR below compares full-year totals. Period
    # granularity keeps one bucket per (year, period).
    by_year = granularity == "year"
    bucket: dict[object, list] = {}
    for r in records:
        key = r.ref_year if by_year else (r.ref_year, r.period)
        bucket.setdefault(key, []).append(r)

    points: list[PartnerGrowthPoint] = []
    for group in bucket.values():
        # For yearly buckets the point's `period` is
        # the period of the bucket's first record.
        head = group[0]
        x = _sum_primary_value(
            r for r in group if r.flow.flow_code == "X"
        )
        m = _sum_primary_value(
            r for r in group if r.flow.flow_code == "M"
        )
        points.append(
            PartnerGrowthPoint(
                year=head.ref_year,
                period=head.period,
                total_trade=x + m,
                exports=x,
                imports=m,
                record_count=len(group),
            )
        )
    points.sort(key=lambda p: (p.year, p.period))

    first = points[0].total_trade
    last = points[-1].total_trade
    abs_change = last - first
    # Relative change is undefined when starting from zero.
    rel_change = abs_change / first if first != 0 else None

    cagr: Decimal | None = None
    if by_year and len(points) >= 2:
        n_years = points[-1].year - points[0].year
        if n_years > 0:
            cagr = _compute_cagr(first, last, n_years)

    return PartnerGrowth(
        reporter_code=reporter_code,
        partner_code=partner_code,
        points=tuple(points),
        absolute_change=abs_change,
        relative_change=rel_change,
        cagr=cagr,
    )

partner_balance

partner_balance(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    by: str = "trade_balance",
    descending: bool = True,
    limit: int | None = None,
) -> tuple[PartnerBalanceRow, ...]

Compute per-partner trade balance for one reporter.

Equivalent to top_partners(..., by="trade_balance") but typed separately (returns PartnerBalanceRow instead of PartnerRankingRow) so callers can opt into the balance view semantically.

Parameters

dataset The CanonicalDataset to analyse. reporter_code The reporter whose partners to summarise. by One of "trade_balance" (default), "abs_trade_balance", "total_trade", "exports", "imports", or "record_count". descending When True (default), largest values first. limit If supplied, return only the top limit rows (after sorting).

Returns

tuple[PartnerBalanceRow, ...] Sorted by by. Empty tuple when no partners match.

Source code in un_comtrade/analytics/partner.py
def partner_balance(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    by: str = "trade_balance",
    descending: bool = True,
    limit: int | None = None,
) -> tuple[PartnerBalanceRow, ...]:
    """Per-partner trade balance for one reporter.

    Produces the same ranking as `top_partners(...,
    by="trade_balance")` with both flows included.
    The rows are typed as `PartnerBalanceRow` rather
    than `PartnerRankingRow`, so callers can request
    the balance view explicitly.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        The reporter whose partners to summarise.
    by
        One of `"trade_balance"` (default),
        `"abs_trade_balance"`, `"total_trade"`,
        `"exports"`, `"imports"`, or
        `"record_count"`.
    descending
        When `True` (default), largest values
        first.
    limit
        If supplied, return only the top `limit`
        rows (after sorting).

    Returns
    -------
    tuple[PartnerBalanceRow, ...]
        Sorted by `by`. Empty tuple when no
        partners match.

    Raises
    ------
    PartnerAnalyticsError
        If `by` is not a known ranking field or
        `limit` is negative.
    """
    # Validate up front so errors name this function's
    # arguments before `top_partners` is invoked.
    if by not in _PARTNER_RANKING_FIELDS:
        raise PartnerAnalyticsError(
            f"Unknown ranking field {by!r}; "
            f"valid: {sorted(_PARTNER_RANKING_FIELDS)}"
        )
    if limit is not None and limit < 0:
        raise PartnerAnalyticsError(
            "limit must be non-negative"
        )
    _check_canonical_dataset(dataset, fn_name="partner_balance")

    def _to_balance(ranked):
        # Field-for-field copy into the balance type.
        return PartnerBalanceRow(
            partner_code=ranked.partner_code,
            partner_iso3=ranked.partner_iso3,
            partner_name=ranked.partner_name,
            total_exports=ranked.total_exports,
            total_imports=ranked.total_imports,
            trade_balance=ranked.trade_balance,
            total_trade=ranked.total_trade,
            record_count=ranked.record_count,
        )

    ranked_rows = top_partners(
        dataset,
        reporter_code=reporter_code,
        flow=None,
        by=by,
        descending=descending,
        limit=limit,
    )
    return tuple(map(_to_balance, ranked_rows))

bilateral_summary

bilateral_summary(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    partner_code: int,
) -> BilateralSummary | None

Compute the bilateral summary for one reporter / partner pair.

Returns None when no records exist on either side.

Source code in un_comtrade/analytics/partner.py
def bilateral_summary(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    partner_code: int,
) -> BilateralSummary | None:
    """Compute the bilateral summary for one
    reporter / partner pair.

    Combines two record sets:

    - side A: records reported by `reporter_code`
      with `partner_code` as counterparty;
    - side B (mirror): records reported by
      `partner_code` with `reporter_code` as
      counterparty.

    When `reporter_code == partner_code` the two
    sides would be the same records. The mirror side
    is therefore left empty to avoid counting every
    record twice.

    Returns `None` when no records exist on
    either side.
    """
    _check_canonical_dataset(
        dataset, fn_name="bilateral_summary"
    )

    # Side A: reporter == reporter_code,
    # partner == partner_code.
    side_a = _select_records(
        dataset,
        reporter_code=reporter_code,
        partner_code=partner_code,
    )
    # Side B (mirror): reporter == partner_code,
    # partner == reporter_code. For a self-pair this
    # is the same selection as side A, so skip it.
    if partner_code == reporter_code:
        side_b = []
    else:
        side_b = _select_records(
            dataset,
            reporter_code=partner_code,
            partner_code=reporter_code,
        )

    if not side_a and not side_b:
        return None

    a_x = _sum_primary_value(
        r for r in side_a if r.flow.flow_code == "X"
    )
    a_m = _sum_primary_value(
        r for r in side_a if r.flow.flow_code == "M"
    )
    b_x = _sum_primary_value(
        r for r in side_b if r.flow.flow_code == "X"
    )
    b_m = _sum_primary_value(
        r for r in side_b if r.flow.flow_code == "M"
    )

    total_exports = a_x + b_x
    total_imports = a_m + b_m
    total_trade = total_exports + total_imports

    # Partner metadata: prefer side A, where the
    # partner is the record's counterparty. Mirror
    # records are reported *by* `partner_code`, so
    # there the partner's identity is on the record's
    # `reporter` (its `partner` is the original
    # reporter).
    if side_a:
        iso3: str | None = side_a[0].partner.iso3
        name: str | None = side_a[0].partner.name
    else:
        iso3 = side_b[0].reporter.iso3
        name = side_b[0].reporter.name

    # At least one side is non-empty here, so
    # `years` is never empty.
    years = [
        r.ref_year
        for r in (*side_a, *side_b)
    ]
    year_range: tuple[int, int] | None = (
        (min(years), max(years)) if years else None
    )

    return BilateralSummary(
        reporter_code=reporter_code,
        partner_code=partner_code,
        partner_iso3=iso3,
        partner_name=name,
        reporter_to_partner_exports=a_x,
        reporter_to_partner_imports=a_m,
        partner_to_reporter_exports=b_x,
        partner_to_reporter_imports=b_m,
        total_exports=total_exports,
        total_imports=total_imports,
        total_trade=total_trade,
        record_count=len(side_a) + len(side_b),
        year_range=year_range,
    )

commodity

Commodity / HS Analytics (P6-004).

This module is the third concrete analytics submodule built on top of AnalyticsEngine (P6-001). It provides four commodity-level analytics that operate exclusively on CanonicalDataset:

  • top_hs_codes(...) — rank HS codes (commodity_code) by trade value for a given reporter (or globally). Supports flow filter, HS-level filter (2/4/6 digit), and limit.
  • commodity_ranking(...) — rank commodities with optional share field (each commodity's percentage of the grand total).
  • commodity_trend(...) — time-series of trade for one HS code.
  • sector_summaries(...) — aggregate by HS section (the 21 WCO Harmonized System sections identified by Roman numerals), using the standard chapter-to-section mapping.

All monetary fields are Decimal (ADR-0027). All dataclasses are frozen=True (ADR-0013).

The module is decoupled from the transport layer: only stdlib + intra-package imports.

CommodityAnalyticsError

Bases: AnalyticsError

Raised when a commodity-level analytics operation cannot be performed.

Source code in un_comtrade/analytics/commodity.py
class CommodityAnalyticsError(AnalyticsError):
    """Error raised by the commodity-level analytics.

    Signals that a commodity operation could not be
    carried out, including a non-`Decimal` monetary
    value on one of the commodity result rows.
    """

HSCodeRankingRow dataclass

One row of a commodity ranking.

Captures exports / imports / total trade / balance for a single HS code (or commodity code) plus the commodity name (if present in the source records).

Source code in un_comtrade/analytics/commodity.py
@dataclass(frozen=True)
class HSCodeRankingRow:
    """One row of a commodity ranking.

    Captures exports / imports / total trade /
    balance for a single HS code (or commodity
    code) plus the commodity name (if present in
    the source records).
    """

    commodity_code: str
    commodity_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    total_trade: Decimal
    trade_balance: Decimal
    record_count: int

    def __post_init__(self) -> None:
        for f in (
            "total_exports", "total_imports",
            "total_trade", "trade_balance",
        ):
            v = getattr(self, f)
            if not isinstance(v, Decimal):
                raise CommodityAnalyticsError(
                    f"{f} must be Decimal; got {type(v).__name__}"
                )

CommodityRankingRow dataclass

One row of a commodity ranking with an optional share field (each commodity's share of the grand total trade, expressed as a fraction).

share is in [0, 1]. When include_share=False (default), share is None.

Source code in un_comtrade/analytics/commodity.py
@dataclass(frozen=True)
class CommodityRankingRow:
    """One row of a commodity ranking with an
    optional `share` field (each commodity's
    percentage of the grand total trade).

    `share` is in [0, 1]. When
    `include_share=False` (default), `share`
    is `None`.
    """

    commodity_code: str
    commodity_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    total_trade: Decimal
    trade_balance: Decimal
    record_count: int
    share: Decimal | None = None

    def __post_init__(self) -> None:
        for f in (
            "total_exports", "total_imports",
            "total_trade", "trade_balance",
        ):
            v = getattr(self, f)
            if not isinstance(v, Decimal):
                raise CommodityAnalyticsError(
                    f"{f} must be Decimal; got {type(v).__name__}"
                )
        if self.share is not None and not isinstance(
            self.share, Decimal
        ):
            raise CommodityAnalyticsError(
                "share must be Decimal when set"
            )

CommodityTrendPoint dataclass

One point on a commodity trend (one year or one period).

Source code in un_comtrade/analytics/commodity.py
@dataclass(frozen=True)
class CommodityTrendPoint:
    """A single observation (one year or one period)
    in a commodity trend. Monetary fields must be
    `Decimal`."""

    year: int
    period: str
    total_trade: Decimal
    exports: Decimal
    imports: Decimal
    record_count: int

    def __post_init__(self) -> None:
        offender = next(
            (
                attr
                for attr in ("total_trade", "exports", "imports")
                if not isinstance(getattr(self, attr), Decimal)
            ),
            None,
        )
        if offender is not None:
            bad = getattr(self, offender)
            raise CommodityAnalyticsError(
                f"{offender} must be Decimal; got {type(bad).__name__}"
            )

SectorSummaryRow dataclass

One row of the sector summary.

Captures totals per WCO Harmonized System section. The chapter_codes tuple lists the 2-digit chapter numbers that fall within this section.

Source code in un_comtrade/analytics/commodity.py
@dataclass(frozen=True)
class SectorSummaryRow:
    """Trade totals for one WCO Harmonized System
    section.

    `chapter_codes` lists the 2-digit HS chapters
    assigned to the section (empty by default), and
    `hs_code_count` defaults to 0. All monetary
    fields must be `Decimal`.
    """

    sector_id: str
    sector_name: str
    total_exports: Decimal
    total_imports: Decimal
    total_trade: Decimal
    trade_balance: Decimal
    record_count: int
    chapter_codes: tuple[int, ...] = field(
        default_factory=tuple
    )
    hs_code_count: int = 0

    def __post_init__(self) -> None:
        checked = (
            "total_exports",
            "total_imports",
            "total_trade",
            "trade_balance",
        )
        for attr in checked:
            value = getattr(self, attr)
            if isinstance(value, Decimal):
                continue
            raise CommodityAnalyticsError(
                f"{attr} must be Decimal; got {type(value).__name__}"
            )

sector_for_chapter

sector_for_chapter(chapter: int) -> tuple[str, str]

Return (section_id, section_name) for a 2-digit HS chapter code. Returns ("??", "Unknown") for chapters outside the standard WCO HS range (1-98).

Source code in un_comtrade/analytics/commodity.py
def sector_for_chapter(chapter: int) -> tuple[str, str]:
    """Map a 2-digit HS chapter code to its section.

    Looks the chapter up in the module's chapter-to-
    section table and returns the matching
    `(section_id, section_name)` pair. Chapters not in
    the table (documented as outside 1-98) fall back to
    `("??", "Unknown")`.
    """
    fallback = ("??", "Unknown")
    return _CHAPTER_TO_SECTOR.get(chapter, fallback)

top_hs_codes

top_hs_codes(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
    by: str = "total_trade",
    descending: bool = True,
    limit: int | None = None,
    hs_level: int | None = None,
) -> tuple[HSCodeRankingRow, ...]

Rank HS codes by trade value.

Parameters

dataset The CanonicalDataset to analyse. reporter_code If supplied, only records with this reporter contribute. flow "X" keeps exports; "M" keeps imports; None (default) keeps both flows. by "total_trade" (default), "exports", "imports", "trade_balance", "abs_trade_balance", or "record_count". descending When True (default), largest first. limit If supplied, return only the top limit rows. hs_level If supplied (one of 2, 4, 6), keep only records whose commodity code has exactly that many leading digits. Useful for ranking at the HS chapter (2), HS heading (4), or HS subheading (6) level.

Returns

tuple[HSCodeRankingRow, ...] Sorted by by. Empty when no records match.

Source code in un_comtrade/analytics/commodity.py
def top_hs_codes(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
    by: str = "total_trade",
    descending: bool = True,
    limit: int | None = None,
    hs_level: int | None = None,
) -> tuple[HSCodeRankingRow, ...]:
    """Rank HS codes by trade value.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        If supplied, only records with this
        reporter contribute.
    flow
        `"X"` keeps exports; `"M"` keeps imports;
        `None` (default) keeps both flows.
    by
        `"total_trade"` (default), `"exports"`,
        `"imports"`, `"trade_balance"`,
        `"abs_trade_balance"`, or `"record_count"`.
    descending
        When `True` (default), largest first.
    limit
        If supplied, return only the top `limit`
        rows. Must be non-negative.
    hs_level
        If supplied (one of `2`, `4`, `6`), keep
        only records whose commodity code has
        exactly that many leading digits. Useful
        for ranking at the HS chapter (2), HS
        heading (4), or HS subheading (6) level.

    Returns
    -------
    tuple[HSCodeRankingRow, ...]
        Sorted by `by`. Empty when no records
        match.

    Raises
    ------
    CommodityAnalyticsError
        If `by` is not a known ranking field,
        `limit` is negative, or `hs_level` is not
        one of 2, 4, 6.
    """
    if by not in _RANKING_FIELDS:
        raise CommodityAnalyticsError(
            f"Unknown ranking field {by!r}; "
            f"valid: {sorted(_RANKING_FIELDS)}"
        )
    if limit is not None and limit < 0:
        raise CommodityAnalyticsError(
            "limit must be non-negative"
        )
    if hs_level is not None and hs_level not in (2, 4, 6):
        raise CommodityAnalyticsError(
            "hs_level must be one of 2, 4, 6"
        )
    _check_canonical_dataset(dataset, fn_name="top_hs_codes")

    selected = []
    for record in dataset.records:
        if (
            reporter_code is not None
            and record.reporter.reporter_code != reporter_code
        ):
            continue
        if flow is not None and record.flow.flow_code != flow:
            continue
        if hs_level is not None:
            code = record.commodity.commodity_code
            # Length of the leading run of digits: the
            # index of the first non-digit character,
            # or the full length when every character
            # is a digit. A 6-digit code is at the
            # subheading level, NOT the chapter level,
            # so only an exact match is kept.
            leading_digits = next(
                (i for i, ch in enumerate(code) if not ch.isdigit()),
                len(code),
            )
            if leading_digits != hs_level:
                continue
        selected.append(record)
    return _aggregate_by_commodity(
        selected,
        by=by,
        descending=descending,
        limit=limit,
        flow=flow,
    )

commodity_ranking

commodity_ranking(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
    by: str = "total_trade",
    descending: bool = True,
    limit: int | None = None,
    hs_level: int | None = None,
    include_share: bool = False,
) -> tuple[CommodityRankingRow, ...]

Rank commodities with optional share.

Same shape as top_hs_codes(...) but with an optional share field (commodity.total_trade / grand_total_trade) that lets callers see each commodity's percentage of the grand total. Useful for concentration analysis (e.g. "top 5 commodities account for 60% of trade").

Parameters

include_share When True, attach a share field (in [0, 1]) to each row.

Source code in un_comtrade/analytics/commodity.py
def commodity_ranking(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
    by: str = "total_trade",
    descending: bool = True,
    limit: int | None = None,
    hs_level: int | None = None,
    include_share: bool = False,
) -> tuple[CommodityRankingRow, ...]:
    """Rank commodities with optional share.

    Same shape as `top_hs_codes(...)` but with
    an optional `share` field
    (`commodity.total_trade / grand_total_trade`)
    that lets callers see each commodity's
    percentage of the grand total. Useful for
    concentration analysis (e.g. "top 5
    commodities account for 60% of trade").

    Parameters
    ----------
    dataset, reporter_code, flow, by, descending,
    limit, hs_level
        Forwarded unchanged to `top_hs_codes(...)`;
        see that function for their meaning.
    include_share
        When `True`, attach a `share` field to each
        row: the row's `total_trade` divided by the
        grand total of every record for
        `reporter_code` (all records when it is
        `None`). The `flow` and `hs_level` filters
        are NOT applied to that denominator. `share`
        is `None` when the grand total is zero or
        when `include_share` is `False`.

    Returns
    -------
    tuple[CommodityRankingRow, ...]
        Rows in the order produced by
        `top_hs_codes(...)`; empty when no records
        match.

    Raises
    ------
    CommodityAnalyticsError
        If `include_share` is not a `bool`, or for
        any argument `top_hs_codes(...)` rejects.
    """
    if not isinstance(include_share, bool):
        raise CommodityAnalyticsError(
            "include_share must be a bool"
        )
    base_rows = top_hs_codes(
        dataset,
        reporter_code=reporter_code,
        flow=flow,
        by=by,
        descending=descending,
        limit=limit,
        hs_level=hs_level,
    )
    if not base_rows:
        return ()

    # Denominator for `share`. It is scoped to the
    # requested reporter only; the flow and HS-level
    # filters are deliberately ignored, so shares from
    # different flow/level filters for the same
    # reporter refer to the same total. Computed only
    # when shares are requested.
    grand_total: Decimal | None = None
    if include_share:
        grand_total = _sum_primary_value(
            r for r in dataset.records
            if reporter_code is None
            or r.reporter.reporter_code == reporter_code
        )

    rows: list[CommodityRankingRow] = []
    for r in base_rows:
        share: Decimal | None = None
        if grand_total is not None and grand_total != 0:
            share = r.total_trade / grand_total
        rows.append(
            CommodityRankingRow(
                commodity_code=r.commodity_code,
                commodity_name=r.commodity_name,
                total_exports=r.total_exports,
                total_imports=r.total_imports,
                total_trade=r.total_trade,
                trade_balance=r.trade_balance,
                record_count=r.record_count,
                share=share,
            )
        )
    return tuple(rows)

commodity_trend

commodity_trend(
    dataset: CanonicalDataset,
    *,
    commodity_code: str,
    reporter_code: int | None = None,
    granularity: str = "year",
) -> tuple[CommodityTrendPoint, ...]

Build a commodity trend for one HS code.

Parameters

dataset The CanonicalDataset to analyse. commodity_code The HS / commodity code to track (exact match against record.commodity.commodity_code). reporter_code If supplied, only records with this reporter contribute. granularity "year" (default) groups by ref_year; "period" groups by period string.

Returns

tuple[CommodityTrendPoint, ...] Sorted by (year, period). Empty when no records match.

Source code in un_comtrade/analytics/commodity.py
def commodity_trend(
    dataset: CanonicalDataset,
    *,
    commodity_code: str,
    reporter_code: int | None = None,
    granularity: str = "year",
) -> tuple[CommodityTrendPoint, ...]:
    """Build a commodity trend for one HS code.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    commodity_code
        The HS / commodity code to track (exact
        match against `record.commodity.commodity_code`).
        Must be a non-empty string.
    reporter_code
        If supplied, only records with this
        reporter contribute.
    granularity
        `"year"` (default) groups by `ref_year`:
        every matching record of a year lands in a
        single point whose `period` is `str(year)`.
        `"period"` groups by `(ref_year, period)`,
        producing one point per distinct period
        string.

    Returns
    -------
    tuple[CommodityTrendPoint, ...]
        Sorted by `(year, period)`. Empty when
        no records match. `exports` / `imports`
        sum the `"X"` / `"M"` records only, while
        `record_count` counts every matching record
        in the bucket.

    Raises
    ------
    CommodityAnalyticsError
        If `commodity_code` is empty or not a
        string, or `granularity` is not `"year"` or
        `"period"`.
    """
    if not isinstance(commodity_code, str) or not commodity_code:
        raise CommodityAnalyticsError(
            "commodity_code must be a non-empty string"
        )
    if granularity not in ("year", "period"):
        raise CommodityAnalyticsError(
            f"Unknown granularity {granularity!r}; "
            f"valid: 'year', 'period'"
        )
    _check_canonical_dataset(dataset, fn_name="commodity_trend")

    selected = []
    for r in dataset.records:
        if r.commodity.commodity_code != commodity_code:
            continue
        if reporter_code is not None:
            if r.reporter.reporter_code != reporter_code:
                continue
        selected.append(r)
    if not selected:
        return ()

    # Bucket key depends on `granularity`. For "year"
    # the period component is normalised to the year
    # string so that sub-annual records (e.g.
    # "202201", "202202") collapse into one yearly
    # point, mirroring `annual_trend`'s period label.
    bucket: dict[tuple[int, str], list] = {}
    for r in selected:
        if granularity == "year":
            key = (r.ref_year, str(r.ref_year))
        else:
            key = (r.ref_year, r.period)
        bucket.setdefault(key, []).append(r)

    points: list[CommodityTrendPoint] = []
    for (year, period), group in bucket.items():
        x = _sum_primary_value(
            r for r in group if r.flow.flow_code == "X"
        )
        m = _sum_primary_value(
            r for r in group if r.flow.flow_code == "M"
        )
        points.append(
            CommodityTrendPoint(
                year=year,
                period=period,
                total_trade=x + m,
                exports=x,
                imports=m,
                record_count=len(group),
            )
        )
    points.sort(key=lambda p: (p.year, p.period))
    return tuple(points)

sector_summaries

sector_summaries(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
) -> tuple[SectorSummaryRow, ...]

Build sector summaries (per WCO HS section).

Parameters

dataset The CanonicalDataset to analyse. reporter_code If supplied, only records with this reporter contribute. flow "X" keeps exports; "M" keeps imports; None (default) keeps both flows.

Returns

tuple[SectorSummaryRow, ...] One row per WCO HS section (21 sections plus an "Unknown" pseudo-section for commodity codes outside the HS range). Sections with zero records are still included (with zero totals) so callers can render a complete matrix.

Source code in un_comtrade/analytics/commodity.py
def sector_summaries(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
) -> tuple[SectorSummaryRow, ...]:
    """Build sector summaries (per WCO HS section).

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        If supplied, only records with this
        reporter contribute.
    flow
        `"X"` keeps exports; `"M"` keeps imports;
        `None` (default) keeps both flows.

    Returns
    -------
    tuple[SectorSummaryRow, ...]
        One row per section listed in `SECTORS`
        (in that order), followed by an "??" /
        "Unknown" pseudo-section for commodity
        codes whose chapter cannot be mapped.
        Sections with zero records are still
        included (with zero totals) so callers
        can render a complete matrix.
        `total_exports` / `total_imports` sum only
        `"X"` / `"M"` records, while `record_count`
        counts every selected record in the section.
    """
    _check_canonical_dataset(dataset, fn_name="sector_summaries")

    # Apply the optional reporter / flow filters.
    selected = []
    for r in dataset.records:
        if reporter_code is not None:
            if r.reporter.reporter_code != reporter_code:
                continue
        if flow is not None:
            if r.flow.flow_code != flow:
                continue
        selected.append(r)

    # Group by sector_id.
    by_sector_x: dict[str, Decimal] = {}
    by_sector_m: dict[str, Decimal] = {}
    by_sector_chapters: dict[str, set[int]] = {}
    by_sector_codes: dict[str, set[str]] = {}
    counts: dict[str, int] = {}
    # First section name seen for each section id.
    sector_meta: dict[str, str] = {}

    # F-002: bucket records by sector_id so the
    # per-flow Decimal sums can be delegated to
    # `summarize(...)` (the Query Engine
    # aggregation primitive called below). The
    # buckets themselves are NOT aggregations —
    # they are pre-grouping routing structures; the
    # actual Decimal summation is performed by the
    # Query Engine.
    buckets_x: dict[str, list] = {}
    buckets_m: dict[str, list] = {}

    for record in selected:
        chapter = _hs_chapter(record.commodity.commodity_code)
        if chapter is None:
            section_id, section_name = "??", "Unknown"
        else:
            section_id, section_name = sector_for_chapter(chapter)
        if section_id not in sector_meta:
            sector_meta[section_id] = section_name
        # A `None` chapter is recorded as 0 in the
        # chapter set (so the "??" row can list 0
        # among its chapter_codes).
        by_sector_chapters.setdefault(section_id, set()).add(chapter or 0)
        by_sector_codes.setdefault(section_id, set()).add(
            record.commodity.commodity_code
        )
        # Counts every selected record, including
        # flow codes other than "X"/"M", which are
        # not routed to either sum bucket below.
        counts[section_id] = counts.get(section_id, 0) + 1
        if record.flow.flow_code == "X":
            buckets_x.setdefault(section_id, []).append(record)
        elif record.flow.flow_code == "M":
            buckets_m.setdefault(section_id, []).append(record)

    # F-002: delegate the per-sector per-flow
    # Decimal summation to the Query Engine
    # `summarize(...)` primitive. A `None` sum is
    # treated as zero.
    for section_id, bucket in buckets_x.items():
        s = summarize(
            tuple(bucket), field="trade_value.primary_value"
        )
        by_sector_x[section_id] = (
            s.sum if s.sum is not None else Decimal("0")
        )
    for section_id, bucket in buckets_m.items():
        s = summarize(
            tuple(bucket), field="trade_value.primary_value"
        )
        by_sector_m[section_id] = (
            s.sum if s.sum is not None else Decimal("0")
        )

    # Counter-flow zeroing. Defensive: `selected` is
    # already filtered to `flow`, so the opposite-flow
    # dict is empty here and these loops do nothing.
    if flow == "X":
        for code in by_sector_m:
            by_sector_m[code] = Decimal("0")
    elif flow == "M":
        for code in by_sector_x:
            by_sector_x[code] = Decimal("0")

    # Build rows in section order (Roman
    # numerals I..XXI then "??") so callers see a
    # stable order. Only ids in `SECTORS` plus "??"
    # get a row; any other id produced above would
    # be dropped here.
    section_order = [sid for sid, _, _ in SECTORS] + ["??"]
    rows: list[SectorSummaryRow] = []
    for section_id in section_order:
        x = by_sector_x.get(section_id, Decimal("0"))
        m = by_sector_m.get(section_id, Decimal("0"))
        rows.append(
            SectorSummaryRow(
                sector_id=section_id,
                # Prefer the name observed in the data;
                # otherwise fall back to the `SECTORS`
                # table. "??" is always "Unknown".
                sector_name=sector_meta.get(
                    section_id,
                    next(
                        sname for sid, sname, _ in SECTORS
                        if sid == section_id
                    ),
                ) if section_id != "??" else "Unknown",
                total_exports=x,
                total_imports=m,
                total_trade=x + m,
                trade_balance=x - m,
                record_count=counts.get(section_id, 0),
                chapter_codes=tuple(
                    sorted(by_sector_chapters.get(section_id, set()))
                ),
                hs_code_count=len(
                    by_sector_codes.get(section_id, set())
                ),
            )
        )
    return tuple(rows)

timeseries

Time-series analytics (P6-005).

This module is the fourth concrete analytics submodule built on top of the AnalyticsEngine foundation (P6-001). It provides five time-series analytics: two trend builders that read a CanonicalDataset, and three helpers that transform the resulting sequence of TrendPoints:

  • annual_trend(...) — yearly time-series of sum_primary_value() (or any user-supplied metric) for a reporter / partner / commodity, with optional flow filter.
  • monthly_trend(...) — same shape but bucketed per month (UN Comtrade periods "202201".."202212" are parsed for year + month; pure-year periods like "2022" are excluded).
  • rolling_average(points, *, window=3) — trailing rolling mean over the last window points (default 3), applied to any time-series of TrendPoints.
  • cagr(points, *, field="value") — compound annual growth rate between the first and last point of a series.
  • growth_rates(points, *, field="value") — per-point period-over-period growth rates (relative change).

All monetary fields are Decimal (ADR-0027). All dataclasses are frozen=True (ADR-0013).

The module is decoupled from the transport layer (same constraint as AnalyticsEngine): only stdlib + intra-package imports.

TimeSeriesAnalyticsError

Bases: AnalyticsError

Raised when a time-series analytics operation cannot be performed.

Source code in un_comtrade/analytics/timeseries.py
class TimeSeriesAnalyticsError(AnalyticsError):
    """Signals that a time-series analytics call
    could not be completed (for example, an invalid
    rolling window or a malformed trend point)."""

TrendPoint dataclass

One point on a time-series trend.

year is the calendar year. period is the canonical period string from the source record ("2022", "202201", etc.). value is the metric value at this point. For monthly trends month is set; for annual trends it's None.

record_count is the number of source records that contributed to this point.

Source code in un_comtrade/analytics/timeseries.py
@dataclass(frozen=True)
class TrendPoint:
    """One point on a time-series trend.

    `year` is the calendar year. `period` is the
    canonical period string from the source
    record (`"2022"`, `"202201"`, etc.). `value`
    is the metric value at this point. For
    monthly trends `month` is set; for annual
    trends it's `None`.

    `record_count` is the number of source
    records that contributed to this point.
    """

    year: int
    period: str
    value: Decimal
    record_count: int
    month: int | None = None

    def __post_init__(self) -> None:
        if not isinstance(self.value, Decimal):
            raise TimeSeriesAnalyticsError(
                f"value must be Decimal; got {type(self.value).__name__}"
            )
        if not isinstance(self.year, int):
            raise TimeSeriesAnalyticsError(
                f"year must be int; got {type(self.year).__name__}"
            )
        if self.month is not None and (
            not isinstance(self.month, int)
            or self.month < 1
            or self.month > 12
        ):
            raise TimeSeriesAnalyticsError(
                f"month must be int in 1..12 or None; got {self.month!r}"
            )

GrowthRatePoint dataclass

Per-point growth rate observation.

growth is (current - previous) / previous as a fraction. previous is None for the first point (no prior value to compare against).

Source code in un_comtrade/analytics/timeseries.py
@dataclass(frozen=True)
class GrowthRatePoint:
    """Per-point growth rate observation.

    `growth` is `(current - previous) / previous`
    as a fraction. `previous` is `None` for the
    first point (no prior value to compare
    against).
    """

    year: int
    period: str
    value: Decimal
    previous: Decimal | None
    growth: Decimal | None
    record_count: int
    month: int | None = None

    def __post_init__(self) -> None:
        for f in ("value", "previous", "growth"):
            v = getattr(self, f)
            if v is not None and not isinstance(v, Decimal):
                raise TimeSeriesAnalyticsError(
                    f"{f} must be Decimal or None; got {type(v).__name__}"
                )

annual_trend

annual_trend(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
    partner_code: int | None = None,
    commodity_code: str | None = None,
    metric: Metric | None = None,
) -> tuple[TrendPoint, ...]

Build an annual time-series trend.

Parameters

dataset The CanonicalDataset to analyse. reporter_code, partner_code, flow, commodity_code Optional filters (any combination). metric The Metric to compute per year. Defaults to Metric.sum_primary_value() (total trade value).

Returns

tuple[TrendPoint, ...] Sorted ascending by year. Empty when no records match.

Source code in un_comtrade/analytics/timeseries.py
def annual_trend(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
    partner_code: int | None = None,
    commodity_code: str | None = None,
    metric: Metric | None = None,
) -> tuple[TrendPoint, ...]:
    """Build a yearly time-series trend.

    Matching records are grouped by calendar year. The
    metric is then evaluated once per year, on a
    sub-dataset holding only that year's records.

    Parameters
    ----------
    dataset
        Source `CanonicalDataset`.
    reporter_code, partner_code, flow,
    commodity_code
        Optional record filters; any combination may
        be supplied.
    metric
        `Metric` evaluated per year. When omitted, the
        default sum-of-primary-value metric is used.

    Returns
    -------
    tuple[TrendPoint, ...]
        One point per year in ascending year order,
        with `period` set to the year string; `()`
        when nothing matches the filters.
    """
    _check_canonical_dataset(dataset, fn_name="annual_trend")
    if metric is None:
        chosen_metric = _metric_for_sum()
    else:
        chosen_metric = _coerce_metric(metric)
    matching = _select_records(
        dataset,
        reporter_code=reporter_code,
        partner_code=partner_code,
        flow_code=flow,
        commodity_code=commodity_code,
    )
    if not matching:
        return ()

    def _year_slice(records):
        # Sub-dataset that inherits the parent's
        # identity/metadata but holds one year's records.
        return CanonicalDataset(
            name=dataset.name,
            records=tuple(records),
            schema_version=dataset.schema_version,
            extracted_at=dataset.extracted_at,
            parser_name=dataset.parser_name,
            skipped=0,
            duplicates_removed=0,
            source_count=len(records),
            metadata=dict(dataset.metadata),
        )

    yearly = _bucket_records(matching, granularity="year")
    trend: list[TrendPoint] = []
    for (year, _unused_month), records in yearly.items():
        raw = chosen_metric.compute(_year_slice(records))
        trend.append(
            TrendPoint(
                year=year,
                period=str(year),
                value=_to_decimal(raw),
                record_count=len(records),
            )
        )
    return tuple(sorted(trend, key=lambda point: point.year))

monthly_trend

monthly_trend(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
    partner_code: int | None = None,
    commodity_code: str | None = None,
    metric: Metric | None = None,
) -> tuple[TrendPoint, ...]

Build a monthly time-series trend.

Same shape as annual_trend(...) but bucketed per month. Records with annual-only period strings (e.g. "2022") are excluded because they cannot be mapped to a specific month.

Source code in un_comtrade/analytics/timeseries.py
def monthly_trend(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    flow: str | None = None,
    partner_code: int | None = None,
    commodity_code: str | None = None,
    metric: Metric | None = None,
) -> tuple[TrendPoint, ...]:
    """Build a monthly time-series trend.

    Works like `annual_trend(...)` but groups by
    `(year, month)`. Records whose period string is
    annual-only (e.g. `"2022"`) have no month and are
    left out of the result.

    Returns
    -------
    tuple[TrendPoint, ...]
        One point per `(year, month)` with `period`
        formatted as `YYYYMM`, sorted ascending; `()`
        when nothing matches the filters.
    """
    _check_canonical_dataset(dataset, fn_name="monthly_trend")
    if metric is None:
        chosen_metric = _metric_for_sum()
    else:
        chosen_metric = _coerce_metric(metric)
    matching = _select_records(
        dataset,
        reporter_code=reporter_code,
        partner_code=partner_code,
        flow_code=flow,
        commodity_code=commodity_code,
    )
    if not matching:
        return ()

    def _month_slice(records):
        # Sub-dataset that inherits the parent's
        # identity/metadata but holds one month's records.
        return CanonicalDataset(
            name=dataset.name,
            records=tuple(records),
            schema_version=dataset.schema_version,
            extracted_at=dataset.extracted_at,
            parser_name=dataset.parser_name,
            skipped=0,
            duplicates_removed=0,
            source_count=len(records),
            metadata=dict(dataset.metadata),
        )

    monthly = _bucket_records(matching, granularity="month")
    trend: list[TrendPoint] = []
    for (year, month), records in monthly.items():
        # Month-granularity buckets always carry a month.
        assert month is not None
        raw = chosen_metric.compute(_month_slice(records))
        trend.append(
            TrendPoint(
                year=year,
                period=f"{year}{month:02d}",
                value=_to_decimal(raw),
                record_count=len(records),
                month=month,
            )
        )
    return tuple(
        sorted(trend, key=lambda point: (point.year, point.month))
    )

rolling_average

rolling_average(
    points: Sequence[TrendPoint],
    *,
    window: int = 3,
    field: str = "value",
) -> tuple[TrendPoint, ...]

Compute the rolling average of a time-series over a window.

At each index i, the output point's field value is the mean of the input field values from max(0, i - window + 1) through i (inclusive — i.e., a trailing window). The first window - 1 output points are based on a partial window (e.g. for window=3, index 0 uses just point 0, index 1 uses points 0–1, etc.).

Parameters

points Input time-series. Should be sorted by (year, period). window Number of consecutive points to average. Default 3. field Name of the dataclass attribute to average. Default "value".

Returns

tuple[TrendPoint, ...] Same length as the input. Each point's field is replaced with the rolling average; all other attributes are preserved from the input point.

Source code in un_comtrade/analytics/timeseries.py
def rolling_average(
    points: Sequence[TrendPoint],
    *,
    window: int = 3,
    field: str = "value",
) -> tuple[TrendPoint, ...]:
    """Smooth a time-series with a trailing rolling mean.

    Output point `i` carries, in attribute `field`, the
    mean of the input `field` values at indices
    `max(0, i - window + 1)` through `i` inclusive. The
    first `window - 1` outputs therefore average fewer
    points (with `window=3`: index 0 uses point 0 alone,
    index 1 uses points 0-1, and so on). Windows are
    counted in points, not calendar time, so gaps in the
    series are not taken into account.

    Parameters
    ----------
    points
        Input time-series, expected to be sorted by
        `(year, period)`.
    window
        How many consecutive points each mean covers.
        Default `3`.
    field
        Numeric attribute to smooth. Default `"value"`.

    Returns
    -------
    tuple[TrendPoint, ...]
        Same length and order as `points`; each point
        is a copy with only `field` replaced by the
        rolling mean.

    Raises
    ------
    TimeSeriesAnalyticsError
        If `window < 1`, an element is not a
        `TrendPoint`, or a `field` value is not numeric.
    """
    if window < 1:
        raise TimeSeriesAnalyticsError(
            "window must be at least 1"
        )
    if not points:
        return ()
    if any(not isinstance(p, TrendPoint) for p in points):
        raise TimeSeriesAnalyticsError(
            "points must be a sequence of TrendPoint"
        )

    # Convert every field value to Decimal up front,
    # rejecting anything non-numeric.
    numbers: list[Decimal] = []
    for p in points:
        raw = getattr(p, field)
        if isinstance(raw, (Decimal, int, float)):
            numbers.append(_to_decimal(raw))
            continue
        raise TimeSeriesAnalyticsError(
            f"point.{field} must be numeric; got {type(raw).__name__}"
        )

    def _mean_ending_at(end: int) -> Decimal:
        # Trailing window: clipped at index 0.
        span = numbers[max(0, end - window + 1):end + 1]
        return sum(span, start=Decimal("0")) / Decimal(len(span))

    return tuple(
        replace(p, **{field: _mean_ending_at(idx)})
        for idx, p in enumerate(points)
    )

cagr

cagr(
    points: Sequence[TrendPoint],
    *,
    field: str = "value",
    years: int | None = None,
) -> Decimal | None

Compute the Compound Annual Growth Rate between the first and last point of a series.

Parameters

points Input time-series (sorted ascending). field Dataclass attribute to use. Default "value". years Override for the time span (in years). When None, derived from the year difference between the first and last points.

Returns

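Decimal | None The CAGR as a fraction (e.g. Decimal("0.5") for 50% annual growth). Returns None when the calculation is undefined: a zero or negative first value, a zero or negative last value, a non-positive span, or fewer than 2 points. Returns Decimal("0") when both the first and last values are zero.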

Source code in un_comtrade/analytics/timeseries.py
def cagr(
    points: Sequence[TrendPoint],
    *,
    field: str = "value",
    years: int | None = None,
) -> Decimal | None:
    """Compute the Compound Annual Growth Rate
    between the first and last point of a
    series.

    The computation is carried out entirely in
    `Decimal` arithmetic (current decimal context),
    so large or very precise monetary values are not
    squeezed through binary floats.

    Parameters
    ----------
    points
        Input time-series (sorted ascending).
    field
        Dataclass attribute to use. Default
        `"value"`.
    years
        Override for the time span (in years).
        When `None`, derived from the
        `year` difference between the first
        and last points.

    Returns
    -------
    Decimal | None
        The CAGR as a fraction (e.g.
        `Decimal("0.5")` for 50% annual growth).
        Returns `None` when the calculation is
        undefined: zero or negative first value,
        zero or negative last value, non-positive
        span, fewer than 2 points, or an arithmetic
        failure. Returns `Decimal("0")` when both
        endpoints are zero.

    Raises
    ------
    TimeSeriesAnalyticsError
        If `points` contains a non-`TrendPoint`.
    """
    if len(points) < 2:
        return None
    if not all(isinstance(p, TrendPoint) for p in points):
        raise TimeSeriesAnalyticsError(
            "points must be a sequence of TrendPoint"
        )

    first = _to_decimal(getattr(points[0], field))
    last = _to_decimal(getattr(points[-1], field))

    if years is None:
        years = points[-1].year - points[0].year
    if years <= 0:
        return None
    if first == 0:
        if last == 0:
            return Decimal("0")
        return None
    if first < 0:
        return None
    try:
        ratio = last / first
        if ratio <= 0:
            return None
        # Decimal power with a non-integral exponent
        # is defined (and correctly rounded) for a
        # positive base, which `ratio > 0` guarantees.
        exponent = Decimal(1) / Decimal(years)
        return ratio ** exponent - 1
    except (ValueError, ArithmeticError):
        # ArithmeticError covers ZeroDivisionError,
        # OverflowError and decimal.DecimalException
        # (e.g. InvalidOperation / Overflow traps).
        return None

growth_rates

growth_rates(
    points: Sequence[TrendPoint], *, field: str = "value"
) -> tuple[GrowthRatePoint, ...]

Compute period-over-period growth rates.

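For each point i ≥ 1, the growth is (value[i] - value[i-1]) / value[i-1]. For i = 0, growth is None (no prior value); growth is also None whenever the previous value is zero. A negative previous value flips the sign of the computed growth.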

Parameters

points Input time-series (sorted ascending). field Dataclass attribute to compare. Default "value".

Returns

tuple[GrowthRatePoint, ...] One row per input point. previous is None for the first row.

Source code in un_comtrade/analytics/timeseries.py
def growth_rates(
    points: Sequence[TrendPoint],
    *,
    field: str = "value",
) -> tuple[GrowthRatePoint, ...]:
    """Compute period-over-period growth rates.

    Row `i` (for `i >= 1`) carries
    `growth = (value[i] - value[i-1]) / value[i-1]`.
    The first row has no predecessor, so both its
    `previous` and `growth` are `None`. `growth`
    is also `None` whenever the preceding value is
    zero, since the ratio is undefined.

    Parameters
    ----------
    points
        Input time-series (sorted ascending).
    field
        Dataclass attribute to compare. Default
        `"value"`. Each value is passed through
        `_to_decimal` before use.

    Returns
    -------
    tuple[GrowthRatePoint, ...]
        One row per input point, in input order.
        Empty input yields an empty tuple.

    Raises
    ------
    TimeSeriesAnalyticsError
        If any element of `points` is not a
        `TrendPoint`.
    """
    if not points:
        return ()
    if any(not isinstance(p, TrendPoint) for p in points):
        raise TimeSeriesAnalyticsError(
            "points must be a sequence of TrendPoint"
        )

    rows: list[GrowthRatePoint] = []
    previous: Decimal | None = None
    for pt in points:
        current = _to_decimal(getattr(pt, field))
        # No predecessor yet, or a zero predecessor: the ratio is undefined.
        if previous is None or previous == 0:
            change: Decimal | None = None
        else:
            change = (current - previous) / previous
        rows.append(
            GrowthRatePoint(
                year=pt.year,
                period=pt.period,
                value=current,
                previous=previous,
                growth=change,
                record_count=pt.record_count,
                month=pt.month,
            )
        )
        previous = current
    return tuple(rows)

balance

Trade-balance analytics (P6-006).

This module is the fifth concrete analytics submodule built on top of the AnalyticsEngine foundation (P6-001). It provides four trade-balance analytics that operate exclusively on CanonicalDataset:

  • country_balance(...) — exports minus imports aggregated per reporter (country). With reporter_code=None, returns balance for ALL reporters (effectively a per-country breakdown of the global balance).
  • partner_trade_balance(...) — exports minus imports aggregated per partner for one reporter.
  • commodity_balance(...) — exports minus imports aggregated per HS code for one reporter (or globally when reporter_code is None).
  • global_balance(...) — global trade balance across all reporters, all partners, all commodities (single BalanceSummary).

All monetary fields are Decimal (ADR-0027). All dataclasses are frozen=True (ADR-0013).

The module is decoupled from the transport layer (same constraint as AnalyticsEngine): only stdlib + intra-package imports.

PartnerBalanceRow dataclass

One row of a partner balance view.

Sibling of PartnerRankingRow — kept as a separate type so callers can opt into the balance view semantically.

Source code in un_comtrade/analytics/partner.py
@dataclass(frozen=True)
class PartnerBalanceRow:
    """A single partner line in a reporter's balance view.

    Deliberately a distinct type from
    `PartnerRankingRow`, even though the shapes
    overlap, so callers opt into balance semantics
    by type.

    The four monetary fields must be `Decimal`;
    anything else raises `PartnerAnalyticsError`.
    """

    partner_code: int
    partner_iso3: str | None
    partner_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    trade_balance: Decimal
    total_trade: Decimal
    record_count: int

    def __post_init__(self) -> None:
        monetary_fields = (
            "total_exports",
            "total_imports",
            "trade_balance",
            "total_trade",
        )
        for field_name in monetary_fields:
            current = getattr(self, field_name)
            if isinstance(current, Decimal):
                continue
            raise PartnerAnalyticsError(
                f"{field_name} must be Decimal; got {type(current).__name__}"
            )

BalanceAnalyticsError

Bases: AnalyticsError

Raised when a balance analytics operation cannot be performed.

Source code in un_comtrade/analytics/balance.py
class BalanceAnalyticsError(AnalyticsError):
    """Signals that a trade-balance analytics call
    could not be completed (for example a negative
    `limit`, or a non-`Decimal` monetary field on a
    balance row)."""

BalanceSummary dataclass

A single-snapshot trade balance summary.

trade_balance = total_exports - total_imports. total_trade = total_exports + total_imports.

Source code in un_comtrade/analytics/balance.py
@dataclass(frozen=True)
class BalanceSummary:
    """Trade balance for one snapshot of data.

    By definition:

    - `trade_balance = total_exports - total_imports`
    - `total_trade = total_exports + total_imports`

    Only the `Decimal` type of the four monetary
    fields is enforced here; the arithmetic
    relationships above are not re-checked. A
    non-`Decimal` value raises
    `BalanceAnalyticsError`.
    """

    total_exports: Decimal
    total_imports: Decimal
    trade_balance: Decimal
    total_trade: Decimal
    record_count: int

    def __post_init__(self) -> None:
        # Report the first monetary field (in declaration order)
        # that is not a Decimal.
        offending = next(
            (
                name
                for name in (
                    "total_exports",
                    "total_imports",
                    "trade_balance",
                    "total_trade",
                )
                if not isinstance(getattr(self, name), Decimal)
            ),
            None,
        )
        if offending is not None:
            bad_value = getattr(self, offending)
            raise BalanceAnalyticsError(
                f"{offending} must be Decimal; got {type(bad_value).__name__}"
            )

CountryBalanceRow dataclass

One row of the country balance breakdown.

Source code in un_comtrade/analytics/balance.py
@dataclass(frozen=True)
class CountryBalanceRow:
    """One reporter's line in a country balance breakdown.

    Monetary fields must be `Decimal`; anything else
    raises `BalanceAnalyticsError`.
    """

    reporter_code: int
    reporter_iso3: str | None
    reporter_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    trade_balance: Decimal
    total_trade: Decimal
    record_count: int

    def __post_init__(self) -> None:
        checked = (
            ("total_exports", self.total_exports),
            ("total_imports", self.total_imports),
            ("trade_balance", self.trade_balance),
            ("total_trade", self.total_trade),
        )
        for label, amount in checked:
            if not isinstance(amount, Decimal):
                raise BalanceAnalyticsError(
                    f"{label} must be Decimal; got {type(amount).__name__}"
                )

CommodityBalanceRow dataclass

One row of the commodity balance breakdown.

Source code in un_comtrade/analytics/balance.py
@dataclass(frozen=True)
class CommodityBalanceRow:
    """One HS-code line in a commodity balance breakdown.

    Monetary fields must be `Decimal`; anything else
    raises `BalanceAnalyticsError`.
    """

    commodity_code: str
    commodity_name: str | None
    total_exports: Decimal
    total_imports: Decimal
    trade_balance: Decimal
    total_trade: Decimal
    record_count: int

    def __post_init__(self) -> None:
        for attr in (
            "total_exports",
            "total_imports",
            "trade_balance",
            "total_trade",
        ):
            amount = getattr(self, attr)
            if isinstance(amount, Decimal):
                continue
            raise BalanceAnalyticsError(
                f"{attr} must be Decimal; got {type(amount).__name__}"
            )

country_balance

country_balance(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> tuple[CountryBalanceRow, ...]

Compute the trade balance per reporter (country).

Parameters

dataset — the CanonicalDataset to analyse. reporter_code — if supplied, restrict to this single reporter; the result is then a zero- or one-element tuple (empty when no records match). descending — when True (default), the largest balances come first. limit — if supplied, return only the top limit rows.

Returns

tuple[CountryBalanceRow, ...] — sorted by trade_balance (descending by default); empty when no records match.

Source code in un_comtrade/analytics/balance.py
def country_balance(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> tuple[CountryBalanceRow, ...]:
    """Compute the trade balance for each reporter
    (country) in the dataset.

    Exports are aggregated from flow `"X"` records
    and imports from flow `"M"` records (via
    `_sum_primary_by_group`). `trade_balance` is
    exports minus imports and `total_trade` is
    their sum. `record_count` counts every selected
    record of the reporter, whatever its flow.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        If supplied, restrict to this single
        reporter; the result is then a zero- or
        one-element tuple. `None` (default) covers
        every reporter.
    descending
        When `True` (default), the largest
        balances come first.
    limit
        If supplied, keep only the first `limit`
        rows after sorting.

    Returns
    -------
    tuple[CountryBalanceRow, ...]
        Sorted by `trade_balance`. Equal balances
        keep ascending `reporter_code` order
        because the sort is stable. Empty when no
        records match.

    Raises
    ------
    BalanceAnalyticsError
        If `limit` is negative.
    """
    if limit is not None and limit < 0:
        raise BalanceAnalyticsError("limit must be non-negative")
    _check_canonical_dataset(dataset, fn_name="country_balance")

    matched = _select_records(dataset, reporter_code=reporter_code)
    if not matched:
        return ()

    # F-002: per-flow sums keyed by reporter code come from the
    # internal Query Engine helper rather than a hand-rolled loop.
    group_key = "reporter.reporter_code"
    exports_by_code = _sum_primary_by_group(
        matched, flow_code="X", group_field=group_key
    )
    imports_by_code = _sum_primary_by_group(
        matched, flow_code="M", group_field=group_key
    )

    # Identity (iso3, name) is taken from the first record seen for
    # each reporter. The Query Engine cannot group on several
    # attributes yet, so this lookup stays per-record.
    identity: dict[int, tuple[Any, Any]] = {}
    tally: dict[int, int] = {}
    for rec in matched:
        rep_code = rec.reporter.reporter_code
        if rep_code not in identity:
            identity[rep_code] = (rec.reporter.iso3, rec.reporter.name)
        tally[rep_code] = tally.get(rep_code, 0) + 1

    zero = Decimal("0")
    rows: list[CountryBalanceRow] = []
    for rep_code in sorted(tally):
        exported = exports_by_code.get(rep_code, zero)
        imported = imports_by_code.get(rep_code, zero)
        iso3, name = identity[rep_code]
        rows.append(
            CountryBalanceRow(
                reporter_code=rep_code,
                reporter_iso3=iso3,
                reporter_name=name,
                total_exports=exported,
                total_imports=imported,
                trade_balance=exported - imported,
                total_trade=exported + imported,
                record_count=tally[rep_code],
            )
        )
    rows.sort(key=lambda row: row.trade_balance, reverse=descending)
    return tuple(rows) if limit is None else tuple(rows[:limit])

partner_trade_balance

partner_trade_balance(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    descending: bool = True,
    limit: int | None = None,
) -> tuple[PartnerBalanceRow, ...]

Compute the trade balance per partner for one reporter.

Parameters

dataset — the CanonicalDataset to analyse. reporter_code — the reporter whose partners to rank. descending — when True (default), the largest balances come first. limit — if supplied, return only the top limit rows.

Returns

tuple[PartnerBalanceRow, ...] — sorted by trade_balance (descending by default); empty when no records match.

Notes

Named partner_trade_balance (not partner_balance) to disambiguate from partner.partner_balance in P6-003, which has a different signature (by=...) and a different shape (per-partner ranking keyed by any sortable field, not strictly trade_balance).

Source code in un_comtrade/analytics/balance.py
def partner_trade_balance(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    descending: bool = True,
    limit: int | None = None,
) -> tuple[PartnerBalanceRow, ...]:
    """Compute the trade balance per partner for
    one reporter.

    Exports are aggregated from flow `"X"` records
    and imports from flow `"M"` records (via
    `_sum_primary_by_group`). `record_count` counts
    every selected record for the partner,
    whatever its flow.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        The reporter whose partners to rank.
        Required; `None` is rejected.
    descending
        When `True` (default), the largest
        balances come first.
    limit
        If supplied, return only the top `limit`
        rows.

    Returns
    -------
    tuple[PartnerBalanceRow, ...]
        Sorted by `trade_balance` (descending by
        default). Empty when no records match.

    Raises
    ------
    BalanceAnalyticsError
        If `limit` is negative or `reporter_code`
        is `None`.

    Notes
    -----
    Named `partner_trade_balance` (not
    `partner_balance`) to disambiguate from
    `partner.partner_balance` in P6-003, which
    has a different signature (`by=...`) and a
    different shape (per-partner ranking keyed
    by any sortable field, not strictly
    `trade_balance`).
    """
    if limit is not None and limit < 0:
        raise BalanceAnalyticsError("limit must be non-negative")
    if reporter_code is None:
        # `_select_records(..., reporter_code=None)` does not filter
        # by reporter; that is how `country_balance` and
        # `commodity_balance` cover all reporters. Here it would
        # silently merge every reporter's partners into one view,
        # so fail loudly instead.
        raise BalanceAnalyticsError(
            "partner_trade_balance requires a reporter_code; got None"
        )
    _check_canonical_dataset(dataset, fn_name="partner_trade_balance")

    selected = _select_records(dataset, reporter_code=reporter_code)
    if not selected:
        return ()

    # F-002: per-flow per-partner aggregation routed
    # through the internal Query Engine (group_by +
    # summarize).
    by_partner_x = _sum_primary_by_group(
        selected, flow_code="X", group_field="partner.partner_code"
    )
    by_partner_m = _sum_primary_by_group(
        selected, flow_code="M", group_field="partner.partner_code"
    )

    # Metadata (iso3 / name) is taken from the first record seen
    # for each partner. The Query Engine does not yet support
    # multi-attribute grouping, so this stays a per-record lookup.
    meta: dict[int, dict[str, Any]] = {}
    counts: dict[int, int] = {}
    for record in selected:
        code = record.partner.partner_code
        if code not in meta:
            meta[code] = {
                "iso3": record.partner.iso3,
                "name": record.partner.name,
            }
        counts[code] = counts.get(code, 0) + 1

    rows: list[PartnerBalanceRow] = []
    for code in sorted(counts):
        x = by_partner_x.get(code, Decimal("0"))
        m = by_partner_m.get(code, Decimal("0"))
        rows.append(
            PartnerBalanceRow(
                partner_code=code,
                partner_iso3=meta[code].get("iso3"),
                partner_name=meta[code].get("name"),
                total_exports=x,
                total_imports=m,
                trade_balance=x - m,
                total_trade=x + m,
                record_count=counts[code],
            )
        )
    # Stable sort: equal balances keep ascending partner_code order.
    rows.sort(key=lambda r: r.trade_balance, reverse=descending)
    if limit is not None:
        rows = rows[:limit]
    return tuple(rows)

commodity_balance

commodity_balance(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> tuple[CommodityBalanceRow, ...]

Compute the trade balance per commodity (HS code).

Parameters

dataset — the CanonicalDataset to analyse. reporter_code — if supplied, restrict to this reporter's trades; when None (default), aggregate across all reporters (a global per-commodity breakdown). descending — when True (default), the largest balances come first. limit — if supplied, return only the top limit rows.

Returns

tuple[CommodityBalanceRow, ...] — sorted by trade_balance (descending by default); empty when no records match.

Source code in un_comtrade/analytics/balance.py
def commodity_balance(
    dataset: CanonicalDataset,
    *,
    reporter_code: int | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> tuple[CommodityBalanceRow, ...]:
    """Compute the trade balance for each commodity
    (HS code).

    Exports are aggregated from flow `"X"` records
    and imports from flow `"M"` records (via
    `_sum_primary_by_group`). `record_count` counts
    every selected record for the HS code,
    whatever its flow.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        If supplied, only this reporter's trades
        are used. `None` (default) aggregates
        across all reporters, giving a global
        per-commodity breakdown.
    descending
        When `True` (default), the largest
        balances come first.
    limit
        If supplied, keep only the first `limit`
        rows after sorting.

    Returns
    -------
    tuple[CommodityBalanceRow, ...]
        Sorted by `trade_balance`. Equal balances
        keep ascending `commodity_code` order
        because the sort is stable. Empty when no
        records match.

    Raises
    ------
    BalanceAnalyticsError
        If `limit` is negative.
    """
    if limit is not None and limit < 0:
        raise BalanceAnalyticsError("limit must be non-negative")
    _check_canonical_dataset(dataset, fn_name="commodity_balance")

    chosen = _select_records(dataset, reporter_code=reporter_code)
    if not chosen:
        return ()

    # F-002: per-flow sums keyed by HS code come from the internal
    # Query Engine helper.
    key_path = "commodity.commodity_code"
    exports_by_hs = _sum_primary_by_group(
        chosen, flow_code="X", group_field=key_path
    )
    imports_by_hs = _sum_primary_by_group(
        chosen, flow_code="M", group_field=key_path
    )

    # The commodity name is taken from the first record seen per HS
    # code. This is a plain lookup, not an aggregation.
    label_by_hs: dict[str, str | None] = {}
    records_by_hs: dict[str, int] = {}
    for rec in chosen:
        hs = rec.commodity.commodity_code
        if hs not in label_by_hs:
            label_by_hs[hs] = rec.commodity.name
        records_by_hs[hs] = records_by_hs.get(hs, 0) + 1

    zero = Decimal("0")
    rows: list[CommodityBalanceRow] = []
    for hs in sorted(records_by_hs):
        exported = exports_by_hs.get(hs, zero)
        imported = imports_by_hs.get(hs, zero)
        rows.append(
            CommodityBalanceRow(
                commodity_code=hs,
                commodity_name=label_by_hs[hs],
                total_exports=exported,
                total_imports=imported,
                trade_balance=exported - imported,
                total_trade=exported + imported,
                record_count=records_by_hs[hs],
            )
        )
    rows.sort(key=lambda row: row.trade_balance, reverse=descending)
    return tuple(rows if limit is None else rows[:limit])

global_balance

global_balance(dataset: CanonicalDataset) -> BalanceSummary

Compute the global trade balance across ALL reporters, ALL partners, ALL commodities and ALL flows.

Returns a single BalanceSummary with total_exports, total_imports, trade_balance (= exports - imports), total_trade (= exports + imports), and record_count.

The flow classification is exhaustive: any record whose flow.flow_code is not "X" (export) is counted as an import. This matches UN Comtrade's two-flow model.

Returns a BalanceSummary with all zero values when the dataset is empty (the caller can detect an empty dataset via record_count == 0).

Source code in un_comtrade/analytics/balance.py
def global_balance(dataset: CanonicalDataset) -> BalanceSummary:
    """Compute one trade-balance snapshot spanning
    every reporter, partner, commodity and flow in
    `dataset`.

    The returned `BalanceSummary` carries
    `total_exports`, `total_imports`,
    `trade_balance` (= exports - imports),
    `total_trade` (= exports + imports) and
    `record_count`. The totals themselves are built
    by `_build_balance_summary` over all of
    `dataset.records`, without any filtering.

    Flow classification (documented contract): a
    record whose `flow.flow_code` is not `"X"`
    (export) is counted as an import.

    NOTE(review): `country_balance`,
    `partner_trade_balance` and `commodity_balance`
    only sum flow `"M"` as imports. If other flow
    codes are present, per-group totals will
    therefore not add up to this global figure.
    Confirm which rule is intended.

    Per the documented contract, an empty dataset
    yields an all-zero summary, detectable via
    `record_count == 0`.
    """
    _check_canonical_dataset(dataset, fn_name="global_balance")
    every_record = dataset.records
    return _build_balance_summary(every_record)

compare

Comparative analytics (P6-007).

This module is the sixth concrete analytics submodule built on top of the AnalyticsEngine foundation (P6-001). It provides four "side-by-side" comparison analytics that operate exclusively on CanonicalDataset:

  • country_vs_country(...) — compare trade profiles of two or more reporters.
  • year_vs_year(...) — compare the same reporter's trade between two periods.
  • commodity_vs_commodity(...) — compare two or more commodities (HS codes).
  • partner_vs_partner(...) — compare two or more partners for one reporter.

All four produce a common shape so callers can swap comparisons without rewriting downstream code:

ComparisonRow(
    dimension_key=...,
    dimension_label=...,
    values=(v1, v2, ...),       # one per side
    deltas=(d1, d2, ...),       # delta vs. first side
    pct_changes=(p1, p2, ...),  # delta / first * 100 (percent)
    record_counts=(c1, c2, ...),
)

All monetary fields are Decimal (ADR-0027). All dataclasses are frozen=True (ADR-0013).

The module is decoupled from the transport layer (same constraint as AnalyticsEngine): only stdlib + intra-package imports.

ComparativeAnalyticsError

Bases: AnalyticsError

Raised when a comparative analytics operation cannot be performed.

Source code in un_comtrade/analytics/compare.py
class ComparativeAnalyticsError(AnalyticsError):
    """Signals that a side-by-side comparison
    could not be performed or that a comparison
    result row or summary holds a value of the
    wrong type."""

ComparisonRow dataclass

One row of a comparative breakdown.

All numeric arrays (values, deltas, pct_changes, record_counts) are aligned by index with the comparison's labels: values[0] corresponds to the first label (comparison.labels[0]), values[1] to the second, etc.

deltas[i] is values[i] - values[0]. pct_changes[i] is (values[i] - values[0]) / values[0] * 100, or None when the baseline (values[0]) is zero (cannot divide) — callers should treat None as "undefined" rather than "no change".

Source code in un_comtrade/analytics/compare.py
@dataclass(frozen=True)
class ComparisonRow:
    """A single row of a side-by-side comparison.

    `values`, `deltas`, `pct_changes` and
    `record_counts` are parallel tuples indexed by
    side. Position 0 is the baseline side (the
    comparison's first label), position 1 the next
    side, and so on.

    - `deltas[i]` is `values[i] - values[0]`.
    - `pct_changes[i]` is
      `(values[i] - values[0]) / values[0] * 100`,
      or `None` when the baseline `values[0]` is
      zero. Treat `None` as "undefined", never as
      "no change".

    `__post_init__` enforces element types only:
    `Decimal` for `values` and `deltas`, `Decimal`
    or `None` for `pct_changes`, and `int` for
    `record_counts`. It does not check that the
    tuples have equal length or that the
    arithmetic relationships hold.
    """

    dimension_key: str
    dimension_label: str | None
    values: tuple[Decimal, ...]
    deltas: tuple[Decimal, ...]
    pct_changes: tuple[Decimal | None, ...]
    record_counts: tuple[int, ...]

    def __post_init__(self) -> None:
        # values / deltas: every element must be a Decimal.
        for attr in ("values", "deltas"):
            for idx, item in enumerate(getattr(self, attr)):
                if isinstance(item, Decimal):
                    continue
                raise ComparativeAnalyticsError(
                    f"{attr}[{idx}] must be Decimal; "
                    f"got {type(item).__name__}"
                )
        # pct_changes: Decimal, or None for an undefined percentage.
        for idx, item in enumerate(self.pct_changes):
            if item is None or isinstance(item, Decimal):
                continue
            raise ComparativeAnalyticsError(
                f"pct_changes[{idx}] must be Decimal or "
                f"None; got {type(item).__name__}"
            )
        # record_counts: plain ints.
        for idx, item in enumerate(self.record_counts):
            if isinstance(item, int):
                continue
            raise ComparativeAnalyticsError(
                f"record_counts[{idx}] must be int; "
                f"got {type(item).__name__}"
            )

ComparisonSummary dataclass

Aggregate totals across all matched records (not filtered by the breakdown dimension).

Source code in un_comtrade/analytics/compare.py
@dataclass(frozen=True)
class ComparisonSummary:
    """Per-side aggregate totals over every matched
    record, independent of the breakdown dimension.

    `labels`, `total_values` and `total_records` are
    expected to be index-aligned, one entry per
    compared side. Only `total_values` is
    type-checked; each element must be a `Decimal`.
    """

    labels: tuple[str, ...]
    total_values: tuple[Decimal, ...]
    total_records: tuple[int, ...]

    def __post_init__(self) -> None:
        # Report the first non-Decimal total, if any.
        offender = next(
            (
                (idx, total)
                for idx, total in enumerate(self.total_values)
                if not isinstance(total, Decimal)
            ),
            None,
        )
        if offender is not None:
            idx, total = offender
            raise ComparativeAnalyticsError(
                f"total_values[{idx}] must be Decimal; "
                f"got {type(total).__name__}"
            )

CountryComparison dataclass

Result of country_vs_country(...).

Source code in un_comtrade/analytics/compare.py
@dataclass(frozen=True)
class CountryComparison:
    """Outcome of `country_vs_country(...)`.

    There is one side per reporter, and the first
    reporter is the baseline.
    """

    # Compared reporters, baseline first.
    reporter_codes: tuple[int, ...]
    # ISO3 / name from the first matching record per reporter (None if absent).
    reporter_iso3: tuple[str | None, ...]
    reporter_names: tuple[str | None, ...]
    # Grouping dimension used for `rows`.
    breakdown_by: str
    # Flow / period filters that were applied (None = unrestricted).
    flow: str | None
    period: str | None
    summary: ComparisonSummary
    rows: tuple[ComparisonRow, ...]

YearComparison dataclass

Result of year_vs_year(...).

Source code in un_comtrade/analytics/compare.py
@dataclass(frozen=True)
class YearComparison:
    """Outcome of `year_vs_year(...)`.

    Compares one reporter between two periods.
    `period_a` is the baseline side.
    """

    # Baseline period, then comparison period.
    period_a: str
    period_b: str
    reporter_code: int
    # From the first matching record for the reporter (None if absent).
    reporter_iso3: str | None
    reporter_name: str | None
    # Grouping dimension used for `rows`.
    breakdown_by: str
    # Flow filter that was applied (None = all flows).
    flow: str | None
    summary: ComparisonSummary
    rows: tuple[ComparisonRow, ...]

CommodityComparison dataclass

Result of commodity_vs_commodity(...).

Source code in un_comtrade/analytics/compare.py
@dataclass(frozen=True)
class CommodityComparison:
    """Outcome of `commodity_vs_commodity(...)`.

    There is one side per HS code, and the first
    code is the baseline.

    NOTE(review): unlike the other comparison
    results, this type has no `flow` field, even
    though `commodity_vs_commodity` accepts a
    `flow` argument.
    """

    # Compared HS codes, baseline first.
    commodity_codes: tuple[str, ...]
    # Presumably one name per code (None where unknown); TODO confirm.
    commodity_names: tuple[str | None, ...]
    # None means all reporters were aggregated.
    reporter_code: int | None
    # Grouping dimension used for `rows`.
    breakdown_by: str
    # Period filter that was applied (None = all periods).
    period: str | None
    summary: ComparisonSummary
    rows: tuple[ComparisonRow, ...]

PartnerComparison dataclass

Result of partner_vs_partner(...).

Source code in un_comtrade/analytics/compare.py
@dataclass(frozen=True)
class PartnerComparison:
    """Outcome of `partner_vs_partner(...)`.

    Compares several partners of a single
    reporter. The first partner is treated as the
    baseline side, by analogy with the other
    comparisons (TODO confirm against
    `partner_vs_partner`).
    """

    # Compared partners, in side order.
    partner_codes: tuple[int, ...]
    partner_iso3: tuple[str | None, ...]
    partner_names: tuple[str | None, ...]
    # The single reporter whose partners are compared.
    reporter_code: int
    # Grouping dimension used for `rows`.
    breakdown_by: str
    # Flow / period filters that were applied (None = unrestricted).
    flow: str | None
    period: str | None
    summary: ComparisonSummary
    rows: tuple[ComparisonRow, ...]

country_vs_country

country_vs_country(
    dataset: CanonicalDataset,
    *,
    reporter_codes: Sequence[int],
    breakdown_by: str = "commodity",
    flow: str | None = None,
    period: str | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> CountryComparison

Compare trade profiles of two or more reporters (countries).

Parameters

dataset — the CanonicalDataset to analyse. reporter_codes — reporter codes to compare (at least 2); the first entry is the baseline. breakdown_by — group-by dimension: "commodity", "partner", or "period". flow — restrict to exports ("X") or imports ("M"); when None, all flows are summed (useful for total trade volume). period — restrict to a single period (e.g. "2022"); when None, all periods are included. descending — sort rows by the last-side delta, descending (True, default) or ascending. limit — if supplied, return only the top limit rows.

Returns

CountryComparison — frozen dataclass with reporter metadata, an aggregate ComparisonSummary, and a tuple of ComparisonRows.

Source code in un_comtrade/analytics/compare.py
def country_vs_country(
    dataset: CanonicalDataset,
    *,
    reporter_codes: Sequence[int],
    breakdown_by: str = "commodity",
    flow: str | None = None,
    period: str | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> CountryComparison:
    """Compare the trade profiles of two or more
    reporters (countries) side by side.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_codes
        Reporter codes to compare (at least 2).
        The first entry is the baseline side.
    breakdown_by
        Group-by dimension: `"commodity"`,
        `"partner"`, or `"period"`.
    flow
        Restrict to exports (`"X"`) or imports
        (`"M"`). `None` sums all flows (total trade
        volume).
    period
        Restrict to a single period (e.g.
        `"2022"`). `None` includes every period.
    descending
        Sort rows by the last-side delta,
        descending (`True`, default) or ascending.
    limit
        If supplied, return only the top `limit`
        rows.

    Returns
    -------
    CountryComparison
        Reporter metadata (ISO3 / name from the
        first matching record, `None` when a
        reporter has no records), the aggregate
        `ComparisonSummary`, and the
        `ComparisonRow`s.
    """
    fn = "country_vs_country"
    _check_canonical_dataset(dataset, fn_name=fn)
    _check_limit(limit, fn_name=fn)
    _check_breakdown_by(breakdown_by, fn_name=fn)
    _check_flow(flow, fn_name=fn)
    _check_codes(reporter_codes, fn_name=fn, label="reporter_codes")

    # One filter spec per reporter (baseline first), plus that
    # reporter's identity from its first record in the dataset.
    sides: list[dict[str, Any]] = []
    iso3: list[str | None] = []
    names: list[str | None] = []
    for code in reporter_codes:
        sides.append(
            {
                "reporter_code": code,
                "flow": flow,
                "period": period,
                "__label__": str(code),
            }
        )
        hit = next(
            (
                rec.reporter
                for rec in dataset.records
                if rec.reporter.reporter_code == code
            ),
            None,
        )
        iso3.append(None if hit is None else hit.iso3)
        names.append(None if hit is None else hit.name)

    summary, rows = _compute_rows(
        dataset,
        sides=sides,
        breakdown_by=breakdown_by,
        descending=descending,
        limit=limit,
    )
    return CountryComparison(
        reporter_codes=tuple(reporter_codes),
        reporter_iso3=tuple(iso3),
        reporter_names=tuple(names),
        breakdown_by=breakdown_by,
        flow=flow,
        period=period,
        summary=summary,
        rows=rows,
    )

year_vs_year

year_vs_year(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    period_a: str,
    period_b: str,
    breakdown_by: str = "commodity",
    flow: str | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> YearComparison

Compare the same reporter's trade between two periods.

Parameters

dataset — the CanonicalDataset to analyse. reporter_code — the reporter whose trade to compare. period_a — baseline period (e.g. "2020" or "202001"). period_b — comparison period. breakdown_by — group-by dimension: "commodity", "partner", or "period". flow — restrict to "X" or "M"; None (default) includes all flows. descending — sort rows by delta (period_b - period_a), descending or ascending. limit — if supplied, return only the top limit rows.

Returns

YearComparison — frozen dataclass with period labels, reporter metadata, a ComparisonSummary, and a tuple of ComparisonRows.

Source code in un_comtrade/analytics/compare.py
def year_vs_year(
    dataset: CanonicalDataset,
    *,
    reporter_code: int,
    period_a: str,
    period_b: str,
    breakdown_by: str = "commodity",
    flow: str | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> YearComparison:
    """Compare one reporter's trade between two
    periods.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    reporter_code
        The reporter whose trade to compare.
    period_a
        Baseline period (e.g. `"2020"` or
        `"202001"`).
    period_b
        Comparison period; must differ from
        `period_a`.
    breakdown_by
        Group-by dimension: `"commodity"`,
        `"partner"`, or `"period"`.
    flow
        Restrict to `"X"` or `"M"`; `None`
        (default) includes all flows.
    descending
        Sort rows by delta (`period_b` -
        `period_a`), descending or ascending.
    limit
        If supplied, return only the top `limit`
        rows.

    Returns
    -------
    YearComparison
        Period labels, reporter metadata (from the
        first matching record, `None` when absent),
        the `ComparisonSummary`, and the
        `ComparisonRow`s.

    Raises
    ------
    ComparativeAnalyticsError
        If `period_a == period_b`.
    """
    fn = "year_vs_year"
    _check_canonical_dataset(dataset, fn_name=fn)
    _check_limit(limit, fn_name=fn)
    _check_breakdown_by(breakdown_by, fn_name=fn)
    _check_flow(flow, fn_name=fn)

    if period_a == period_b:
        raise ComparativeAnalyticsError(
            "year_vs_year requires distinct periods; "
            f"both are {period_a!r}"
        )

    # Baseline side first, then the comparison side. Each side is
    # labelled by its own period string.
    sides: list[dict[str, Any]] = [
        {
            "reporter_code": reporter_code,
            "period": label,
            "flow": flow,
            "__label__": label,
        }
        for label in (period_a, period_b)
    ]

    owner = next(
        (
            rec.reporter
            for rec in dataset.records
            if rec.reporter.reporter_code == reporter_code
        ),
        None,
    )
    iso3 = None if owner is None else owner.iso3
    name = None if owner is None else owner.name

    summary, rows = _compute_rows(
        dataset,
        sides=sides,
        breakdown_by=breakdown_by,
        descending=descending,
        limit=limit,
    )
    return YearComparison(
        period_a=period_a,
        period_b=period_b,
        reporter_code=reporter_code,
        reporter_iso3=iso3,
        reporter_name=name,
        breakdown_by=breakdown_by,
        flow=flow,
        summary=summary,
        rows=rows,
    )

commodity_vs_commodity

commodity_vs_commodity(
    dataset: CanonicalDataset,
    *,
    commodity_codes: Sequence[str],
    reporter_code: int | None = None,
    breakdown_by: str = "partner",
    period: str | None = None,
    flow: str | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> CommodityComparison

Compare trade profiles of two or more commodities (HS codes).

Parameters

dataset — the CanonicalDataset to analyse. commodity_codes — HS codes to compare (at least 2); the first entry is the baseline. reporter_code — restrict to a single reporter; when None (default), aggregate across all reporters. breakdown_by — group-by dimension: "commodity", "partner", or "period". Note: when grouping by "commodity", each row represents a non-compared HS code that appears in the dataset (useful as a "context" view). period — restrict to a single period; when None, all periods are included. flow — restrict to "X" or "M"; None (default) includes all flows. descending — sort rows by the last-side delta, descending or ascending. limit — if supplied, return only the top limit rows.

Returns

CommodityComparison — frozen dataclass with commodity codes, names, the optional reporter, an aggregate ComparisonSummary, and ComparisonRows.

Source code in un_comtrade/analytics/compare.py
def commodity_vs_commodity(
    dataset: CanonicalDataset,
    *,
    commodity_codes: Sequence[str],
    reporter_code: int | None = None,
    breakdown_by: str = "partner",
    period: str | None = None,
    flow: str | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> CommodityComparison:
    """Compare the trade profiles of two or more
    commodities (HS codes).

    Every entry of `commodity_codes` becomes one
    comparison side. All sides share the same
    reporter, period and flow restrictions, so
    only the commodity code differs between them.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    commodity_codes
        HS codes to compare (at least 2 are
        required). The first code is the
        baseline.
    reporter_code
        Limit the comparison to one reporter.
        When `None` (the default), values are
        aggregated across every reporter.
    breakdown_by
        Dimension to group rows by: `"commodity"`,
        `"partner"`, or `"period"`. When grouping
        by `"commodity"`, each row is an HS code
        that is *not* being compared but appears
        in the dataset (a "context" view).
    period
        Limit the comparison to one period. When
        `None`, every period is included.
    flow
        Limit the comparison to `"X"` or `"M"`.
        When `None` (the default), all flows are
        included.
    descending
        Sort rows by the delta of the last side,
        descending (default) or ascending.
    limit
        When given, keep only the top `limit`
        rows.

    Returns
    -------
    CommodityComparison
        Frozen dataclass holding the compared
        codes, their names (`None` for a code
        that no record carries), the optional
        reporter, an aggregate
        `ComparisonSummary`, and the
        `ComparisonRow` entries.
    """
    fn_name = "commodity_vs_commodity"
    _check_canonical_dataset(dataset, fn_name=fn_name)
    _check_limit(limit, fn_name=fn_name)
    _check_breakdown_by(breakdown_by, fn_name=fn_name)
    _check_flow(flow, fn_name=fn_name)
    _check_codes(commodity_codes, fn_name=fn_name, label="commodity_codes")

    # One side per commodity code; the remaining filters are shared.
    sides: list[dict[str, Any]] = [
        {
            "commodity_code": code,
            "reporter_code": reporter_code,
            "period": period,
            "flow": flow,
            "__label__": code,
        }
        for code in commodity_codes
    ]

    # A code's name is taken from the first record carrying that code.
    # Codes absent from the dataset resolve to None.
    names: list[str | None] = [
        next(
            (
                rec.commodity.name
                for rec in dataset.records
                if rec.commodity.commodity_code == side["commodity_code"]
            ),
            None,
        )
        for side in sides
    ]

    comparison_summary, comparison_rows = _compute_rows(
        dataset,
        sides=sides,
        breakdown_by=breakdown_by,
        descending=descending,
        limit=limit,
    )
    return CommodityComparison(
        commodity_codes=tuple(commodity_codes),
        commodity_names=tuple(names),
        reporter_code=reporter_code,
        breakdown_by=breakdown_by,
        period=period,
        summary=comparison_summary,
        rows=comparison_rows,
    )

partner_vs_partner

partner_vs_partner(
    dataset: CanonicalDataset,
    *,
    partner_codes: Sequence[int],
    reporter_code: int,
    breakdown_by: str = "commodity",
    period: str | None = None,
    flow: str | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> PartnerComparison

Compare trade profiles of two or more partners for one reporter.

Parameters

- `dataset`: the `CanonicalDataset` to analyse.
- `partner_codes`: partner codes to compare (must contain at least 2). The first entry is the baseline.
- `reporter_code`: the reporter whose partners to compare.
- `breakdown_by`: group-by dimension, one of `"commodity"`, `"partner"`, or `"period"`.
- `period`: restrict to a single period. When `None`, all periods are included.
- `flow`: restrict to `"X"` or `"M"`. When `None` (default), all flows are included.
- `descending`: sort rows by the last-side delta in descending (default) or ascending order.
- `limit`: if supplied, return only the top `limit` rows.

Returns

`PartnerComparison`: frozen dataclass with the partner codes, their ISO3 codes and names, an aggregate `ComparisonSummary`, and the `ComparisonRow` entries.

Source code in un_comtrade/analytics/compare.py
def partner_vs_partner(
    dataset: CanonicalDataset,
    *,
    partner_codes: Sequence[int],
    reporter_code: int,
    breakdown_by: str = "commodity",
    period: str | None = None,
    flow: str | None = None,
    descending: bool = True,
    limit: int | None = None,
) -> PartnerComparison:
    """Compare the trade profiles of two or more
    partners of a single reporter.

    Every entry of `partner_codes` becomes one
    comparison side. All sides share the same
    reporter, period and flow restrictions, so
    only the partner code differs between them.

    Parameters
    ----------
    dataset
        The `CanonicalDataset` to analyse.
    partner_codes
        Partner codes to compare (at least 2 are
        required). The first code is the
        baseline.
    reporter_code
        The reporter whose partners are being
        compared.
    breakdown_by
        Dimension to group rows by: `"commodity"`,
        `"partner"`, or `"period"`.
    period
        Limit the comparison to one period. When
        `None`, every period is included.
    flow
        Limit the comparison to `"X"` or `"M"`.
        When `None` (the default), all flows are
        included.
    descending
        Sort rows by the delta of the last side,
        descending (default) or ascending.
    limit
        When given, keep only the top `limit`
        rows.

    Returns
    -------
    PartnerComparison
        Frozen dataclass holding the compared
        partner codes, their ISO3 codes and names
        (`None` for a partner that no record
        carries), an aggregate
        `ComparisonSummary`, and the
        `ComparisonRow` entries.
    """
    fn_name = "partner_vs_partner"
    _check_canonical_dataset(dataset, fn_name=fn_name)
    _check_limit(limit, fn_name=fn_name)
    _check_breakdown_by(breakdown_by, fn_name=fn_name)
    _check_flow(flow, fn_name=fn_name)
    _check_codes(partner_codes, fn_name=fn_name, label="partner_codes")

    # One side per partner; the label is the stringified partner code.
    sides: list[dict[str, Any]] = [
        {
            "partner_code": code,
            "reporter_code": reporter_code,
            "period": period,
            "flow": flow,
            "__label__": str(code),
        }
        for code in partner_codes
    ]

    # ISO3 and name come from the first record whose partner matches.
    # This lookup does not filter on reporter_code. A partner with no
    # record resolves to None for both fields.
    iso3: list[str | None] = []
    names: list[str | None] = []
    for side in sides:
        wanted = side["partner_code"]
        hit = next(
            (rec for rec in dataset.records if rec.partner.partner_code == wanted),
            None,
        )
        if hit is None:
            iso3.append(None)
            names.append(None)
        else:
            iso3.append(hit.partner.iso3)
            names.append(hit.partner.name)

    comparison_summary, comparison_rows = _compute_rows(
        dataset,
        sides=sides,
        breakdown_by=breakdown_by,
        descending=descending,
        limit=limit,
    )
    return PartnerComparison(
        partner_codes=tuple(partner_codes),
        partner_iso3=tuple(iso3),
        partner_names=tuple(names),
        reporter_code=reporter_code,
        breakdown_by=breakdown_by,
        flow=flow,
        period=period,
        summary=comparison_summary,
        rows=comparison_rows,
    )

Examples

from un_comtrade import ComtradeClient

with ComtradeClient() as client:
    exports = client.trade.get_exports(reporter_code=699, period="2022")
    top = client.analytics.top_partners(exports, by="exports", limit=5)