# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import argparse
from datetime import datetime
from typing import Optional

import pandas as pd

from nsys_recipe import log
from nsys_recipe.lib import helpers, pace, recipe, summary
from nsys_recipe.lib.args import ArgumentParser, Option
from nsys_recipe.lib.table_config import CompositeTable


class NvtxPace(recipe.Recipe):
    """Recipe that reports per-iteration pacing of a named NVTX range.

    Each input report is processed independently by the mapper. The reducer
    merges the per-report results and writes them as Parquet files. Finally a
    notebook (``pace.ipynb``) and the analysis file are generated.
    """

    @staticmethod
    def _mapper_func(
        report_path: str, parsed_args: argparse.Namespace
    ) -> Optional[tuple[str, pd.DataFrame, pd.DataFrame, int]]:
        """Extract pacing information from a single report.

        Delegates to ``pace.get_pacing_info`` with the composite NVTX table and
        the column name ``"text"`` (presumably the column matched against
        ``--name`` -- TODO confirm against ``pace.get_pacing_info``).

        Args:
            report_path: Path of the report to process.
            parsed_args: Parsed recipe arguments, forwarded unchanged.

        Returns:
            ``(filename, pace_df, stats_df, session_start)`` for the report, or
            ``None`` (presumably when the report yields no usable data --
            TODO confirm).
        """
        return pace.get_pacing_info(
            report_path, parsed_args, CompositeTable.NVTX, "text"
        )

    @log.time("Mapper")
    def mapper_func(
        self, context: recipe.Context
    ) -> list[Optional[tuple[str, pd.DataFrame, pd.DataFrame, int]]]:
        """Run ``_mapper_func`` over every ``--input`` report and wait for all.

        Args:
            context: Execution context used to dispatch and collect the work.

        Returns:
            One mapper result (possibly ``None``) per input report.
        """
        return context.wait(
            context.map(
                self._mapper_func,
                self._parsed_args.input,
                parsed_args=self._parsed_args,
            )
        )

    @log.time("Reducer")
    def reducer_func(
        self, mapper_res: list[Optional[tuple[str, pd.DataFrame, pd.DataFrame, int]]]
    ) -> None:
        """Merge per-report mapper results and write the Parquet outputs.

        Writes ``files.parquet`` (one row per report), ``stats.parquet``
        (aggregated statistics) and one ``pace_<name>.parquet`` per column
        group returned by ``pace.split_columns_as_dataframes``.

        Args:
            mapper_res: Results returned by ``mapper_func``.

        Raises:
            ValueError: If no report produced any data.
        """
        filtered_res = helpers.filter_none_or_empty(mapper_res)
        # Sort by filename so the output row order is deterministic regardless
        # of the order in which mapper tasks completed. ``sorted`` also
        # materializes the result as a list, so the emptiness check is reliable.
        filtered_res = sorted(filtered_res, key=lambda x: x[0])

        if not filtered_res:
            # Without this guard, the unpacking below fails with an opaque
            # "not enough values to unpack" error.
            raise ValueError(
                "No pacing data was produced for any input report "
                "(all mapper results were None or empty)."
            )

        filenames, pace_dfs, stats_dfs, session_starts = zip(*filtered_res)
        # Presumably shifts each pace DataFrame by its report's session start
        # so timestamps are comparable across reports -- TODO confirm.
        pace.apply_time_offset(session_starts, pace_dfs)

        # The index (labelled "Rank") is the report's position in the
        # filename-sorted order.
        files_df = pd.DataFrame({"File": filenames}).rename_axis("Rank")
        files_df.to_parquet(self.add_output_file("files.parquet"))

        stats_df = pd.concat(stats_dfs)
        stats_df = summary.aggregate_stats_df(stats_df)
        stats_df.to_parquet(self.add_output_file("stats.parquet"))

        for name, df in pace.split_columns_as_dataframes(pace_dfs).items():
            df.to_parquet(self.add_output_file(f"pace_{name}.parquet"), index=False)

    def save_notebook(self) -> None:
        """Generate ``pace.ipynb`` and copy the ``nsys_pres.py`` helper."""
        self.create_notebook("pace.ipynb")
        self.add_notebook_helper_file("nsys_pres.py")

    def save_analysis_file(self) -> None:
        """Record the end time and output files, then write the analysis file."""
        self._analysis_dict.update(
            {
                "EndTime": str(datetime.now()),
                "Outputs": self._output_files,
            }
        )
        self.create_analysis_file()

    def run(self, context: recipe.Context) -> None:
        """Execute the recipe: map, reduce, then emit notebook and analysis file.

        Args:
            context: Execution context used by the mapper.
        """
        super().run(context)

        mapper_res = self.mapper_func(context)
        self.reducer_func(mapper_res)

        self.save_notebook()
        self.save_analysis_file()

    @classmethod
    def get_argument_parser(cls) -> ArgumentParser:
        """Build the recipe's argument parser.

        Adds the required input and ``--name`` arguments, optional start/end
        options, and a mutually exclusive group for the time-based and
        NVTX-based filter options.

        Returns:
            The configured parser.
        """
        parser = super().get_argument_parser()

        parser.add_recipe_argument(Option.INPUT, required=True)
        parser.add_recipe_argument(
            "--name",
            type=str,
            help="Name of the NVTX range used as delineator between iterations",
            required=True,
        )
        parser.add_recipe_argument(Option.START)
        parser.add_recipe_argument(Option.END)

        # Only one filtering mode may be chosen at a time.
        filter_group = parser.recipe_group.add_mutually_exclusive_group()
        parser.add_argument_to_group(filter_group, Option.FILTER_TIME)
        parser.add_argument_to_group(filter_group, Option.FILTER_NVTX)

        return parser
