diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py index fd5659161b046..f4b1449281f69 100644 --- a/neuralmagic/tools/profiler/visualize_trace.py +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -1,11 +1,67 @@ import argparse +import copy import json +import math +from pathlib import Path +from typing import Any, List, Optional, Tuple import matplotlib.pyplot as plt import pandas as pd +## JSON parsing utils #### -def trim_string_back(string: str, width: int): + +def largest_dist_from_leaf(node: dict, depth: int = 0): + if len(node["children"]) == 0: + return depth + return max([ + largest_dist_from_leaf(child, depth=depth + 1) + for child in node["children"] + ]) + + +def get_entries_at_depth(depth: int, + entries_and_traces: List[Tuple[Any, Any]], + node: dict, + curr_depth: int = 0, + trace=()): + # assert that the query is at kernel or module level + assert depth == -1 or depth == -2 + + if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1): + # The tree is not tall enough! + entries_and_traces.append((node["entry"], trace)) + return + + if largest_dist_from_leaf(node) == (abs(depth) - 1): + entries_and_traces.append((node["entry"], trace)) + + trace = (node["entry"]["name"], ) + trace + for child in node["children"]: + get_entries_at_depth(depth, + entries_and_traces, + child, + curr_depth=curr_depth + 1, + trace=trace) + + +def fold_nodes(root: dict, nodes_to_fold: List[str]): + + stack: List[dict] = [root] + while len(stack) != 0: + node = stack.pop() + if node['entry']['name'] in nodes_to_fold: + node["children"] = [] + continue + for child in node["children"]: + stack.append(child) + return root + + +## Operation name cleanup utils #### + + +def trim_string_back(string: str, width: int) -> str: if len(string) > width: offset = len(string) - width + 3 string = string[:-offset] @@ -14,7 +70,14 @@ def trim_string_back(string: str, width: int): return string -def abbreviate_known_names(name: str): +def shorten_plot_legend_strings(legend, max_char_len: int): + for t in legend.get_texts(): + t.set_text( + trim_string_back(abbreviate_known_names(t.get_text()), + max_char_len)) + + +def abbreviate_known_names(name: str) -> str: abbreviations = { "MergedColumnParallelLinear": "MCPLinear", "QKVParallelLinear": "QKVPLinear", @@ -28,11 +91,288 @@ def abbreviate_known_names(name: str): return name -def shorten_plot_legend_strings(legend, max_char_len: int): - for t in legend.get_texts(): - t.set_text( - trim_string_back(abbreviate_known_names(t.get_text()), - max_char_len)) +def attempt_to_make_names_unique(entries_and_traces): + names, non_unique_names = (set(), set()) + + def all_the_same(items) -> bool: + return all(i == items[0] for i in items) + + for entry, _ in entries_and_traces: + if entry["name"] in names: + non_unique_names.add(entry["name"]) + else: + names.add(entry["name"]) + + for name in non_unique_names: + entries_and_traces_with_name = [(entry, trace) + for entry, trace in entries_and_traces + if entry["name"] == name] + + zipped_traces = list( + zip(*[trace for _, trace in entries_and_traces_with_name])) + first_trace_difference = next( + (i for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles)), None) + + if first_trace_difference is None: + # can't create a unique name, leave them names as the + # are they will get aggregated by the pivot_table call + continue + + for entry, trace in entries_and_traces_with_name: + entry["name"] = " <- ".join((entry["name"], ) + + 
trace[:first_trace_difference + 1]) + + +## Operation grouping utils #### +''' + Group operations in the given dataframe by some high-level ops like, + - gemms + - attention + - rms_norm + etc. +''' + + +def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: + + def is_rms_norm(op_name: str): + if "rms_norm_kernel" in op_name: + return True + + def is_attention_block(op_name: str): + if "flash_fwd" in op_name or \ + "reshape_and_cache_flash_kernel" in op_name: + return True + + def is_quant(op_name: str): + if "scaled_fp8_quant" in op_name or \ + "scaled_int8_quant" in op_name: + return True + + def is_gemm_op(op_name: str): + if is_quant(op_name): + return False + if "xmma_gemm" in op_name or \ + "gemv2T_kernel" in op_name or \ + "splitKreduce" in op_name or \ + "void cutlass::Kernel" in op_name or \ + "void cutlass::device_kernel" in op_name or \ + "s16816gemm" in op_name: + return True + + def is_elementwise_op(op_name: str): + return "elementwise_kernel" in op_name + + def is_mem_op(op_name: str): + return "memcpy" in op_name.lower() or \ + "memset" in op_name.lower() + + def is_vocab_embedding_op(op_name: str): + return "vocabparallelembed" in op_name.lower() + + headers = list(trace_df) + ops = copy.deepcopy(headers) + + attention_ops = list(filter(lambda x: is_attention_block(x), ops)) + ops = list(filter(lambda x: x not in attention_ops, ops)) + + quant_ops = list(filter(lambda x: is_quant(x), ops)) + ops = list(filter(lambda x: x not in quant_ops, ops)) + + gemm_ops = list(filter(lambda x: is_gemm_op(x), ops)) + ops = list(filter(lambda x: x not in gemm_ops, ops)) + + rms_norm_ops = list(filter(lambda x: is_rms_norm(x), ops)) + ops = list(filter(lambda x: x not in rms_norm_ops, ops)) + + vocab_embed_ops = list(filter(lambda x: is_vocab_embedding_op(x), ops)) + ops = list(filter(lambda x: x not in vocab_embed_ops, ops)) + + mem_ops = list(filter(lambda x: is_mem_op(x), ops)) + ops = list(filter(lambda x: x not in mem_ops, ops)) + + elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops)) + ops = list(filter(lambda x: x not in elementwise_ops, ops)) + + if len(attention_ops): + trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) + if len(quant_ops): + trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + if len(gemm_ops): + trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) + if len(rms_norm_ops): + trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) + if len(vocab_embed_ops): + trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", + axis=1) + if len(mem_ops): + trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) + if len(elementwise_ops): + trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", + axis=1) + + trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + + vocab_embed_ops + mem_ops + elementwise_ops, + axis=1, + inplace=True) + return trace_df + + +## Data plotting utils #### + + +def plot_trace_df(traces_df: pd.DataFrame, + plot_metric: str, + plot_title: str, + output: Optional[Path] = None): + + phases = traces_df['phase'].unique() + traces_df = traces_df.pivot_table(index="phase", + columns="name", + values=plot_metric, + aggfunc="sum") + + traces_df = group_trace_by_operations(traces_df) + + # Make the figure + fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True) + + # Draw the stacked bars + ops = list(traces_df) + bottom = [0] * len(phases) + for op in ops: + values = [traces_df[op][phase] for phase in phases] + values = list(map(lambda x: 
0.0 if math.isnan(x) else x, values)) + ax.bar(phases, values, label=op, bottom=bottom) + bottom = [bottom[j] + values[j] for j in range(len(phases))] + + # Write the values as text on the bars + for bar in ax.patches: + if bar.get_height() != 0: + ax.text(bar.get_x() + bar.get_width() / 2, + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha='center', + color='w', + weight='bold', + size=5) + + # Setup legend + handles, labels = plt.gca().get_legend_handles_labels() + legend = fig.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(1, 1)) + shorten_plot_legend_strings(legend, 50) + + # Setup labels and title + plt.setp(ax.get_xticklabels(), rotation=90) + ax.set_ylabel(plot_metric) + plt.suptitle(plot_title) + + plt.savefig(output, bbox_inches='tight') + print("Created: ", output) + + +def main( + json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + plot_metric: str, + make_names_unique: bool, + top_k: int, + json_nodes_to_fold: List[str]): + + def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame: + + def get_entries_and_traces(key: str): + entries_and_traces: List[Tuple[Any, Any]] = [] + for root in profile_json[key]["summary_stats"]: + # Fold nodes in the traces as per user request. i.e. simply + # make the requested nodes leaf-nodes. + root = fold_nodes(root, json_nodes_to_fold) + get_entries_at_depth(depth, entries_and_traces, root) + return entries_and_traces + + def keep_only_top_entries(df: pd.DataFrame, + metric: str, + top_k: int = 9) -> pd.DataFrame: + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, + ["name"]] = "others" + return df + + # Get data for each key + traces = list(map(lambda x: get_entries_and_traces(x), step_keys)) + + # Attempt some cleanup + if make_names_unique: + for trace in traces: + attempt_to_make_names_unique(trace) + + # To pandas dataframe + trace_dfs = list( + map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), + traces)) + + # Respect top_k + if top_k: + trace_dfs = list( + map( + lambda trace_df: keep_only_top_entries( + trace_df, "cuda_time_us", top_k), trace_dfs)) + + # Fill in information about the step-keys + for trace_df, step_key in zip(trace_dfs, step_keys): + trace_df['phase'] = step_key + + # Combine all data frames so they can be put in a single plot + traces_df = pd.concat(trace_dfs) + + # Add a derived metric `cuda_time_ms` + traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000 + traces_df = traces_df.fillna(0) + + return traces_df + + def make_plot_title_suffix(profile_json: dict) -> str: + context = profile_json["context"] + sparsity = context.get('sparsity', None) + return (f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"OutputLen={context['output_len']}," + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}") + + profile_json = None + with open(json_trace, "r") as f: + profile_json = json.load(f) + assert profile_json is not None + + # Get all `llm.generate.step()` profile + step_traces = list(profile_json.keys()) + assert (step_traces[0] == 'context') + step_traces = step_traces[1:] # have only prefill and decodes + prefills = list(filter(lambda x: "prefill" in x, step_traces)) + all_decodes = list(filter(lambda x: "decode" in x, step_traces)) + assert len(prefills) + len(all_decodes) == len(step_traces) + assert len(prefills) == 1 + + decodes = all_decodes[::args.step_plot_interval] + 
if decodes[-1] != all_decodes[-1]: + # Always have the last decode + decodes.append(all_decodes[-1]) + + prefill_traces = prepare_data(profile_json, prefills) + decode_traces = prepare_data(profile_json, decodes) + + plot_title_suffix = make_plot_title_suffix(profile_json) + + plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix, + output_directory / Path("prefill.png")) + plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix, + output_directory / Path("decode_steps.png")) if __name__ == "__main__": @@ -43,30 +383,39 @@ def shorten_plot_legend_strings(legend, max_char_len: int): type=str, required=True, help="json trace file output by examples/offline_profile.py") - parser.add_argument( - "--output", - type=str, - required=False, - help="Output figure file, should be a image file such as pdf, " - "jpeg, png, etc., defaults to .pdf") + parser.add_argument("--output-directory", + type=str, + required=False, + help="Directory to output plots") parser.add_argument("--level", type=str, default="module", choices=["module", "kernel"]) - parser.add_argument("--top_k", + parser.add_argument("--top-k", type=int, - default=9, + default=12, help="Only graph the top `top_k` entries by time.") - parser.add_argument("--ignore_sampler", - action='store_true', - help="Ignore everything under the \"Sampler\" module") + parser.add_argument("--fold-json-node", + nargs='+', + default=['Sampler', 'LogitsProcessor'], + help='Do not plot the children of these nodes. Let, \ + the node represent the aggregate of all its \ + children') + parser.add_argument("--plot-metric", + type=str, + default="cuda_time_ms", + help='Metric to plot. some options are cuda_time_ms, \ + pct_cuda_time') + parser.add_argument( + "--step-plot-interval", + type=int, + default=4, + help="For every `step_plot_interval` steps, plot 1 step") args = parser.parse_args() - ignore_sampler = args.ignore_sampler + # Prepare/Extract relevant args make_names_unique = False - top_k = args.top_k - if args.level == "module": depth = -2 make_names_unique = True @@ -75,135 +424,8 @@ def shorten_plot_legend_strings(legend, max_char_len: int): else: raise Exception(f"Unexpected level value ({args.level})") - if ignore_sampler: - print("WARNING: ignoring Sampler time so the pct_cuda_time will not " - "add up to 100%") - - json_trace = args.json_trace - output = args.output if args.output else json_trace.strip(".json") + ".pdf" - - with open(json_trace, "r") as f: - profile_data = json.load(f) - - prefill_entries_and_traces = [] - decode_entries_and_traces = [] - - def largest_dist_from_leaf(node, depth=0): - if len(node["children"]) == 0: - return depth - return max([ - largest_dist_from_leaf(child, depth=depth + 1) - for child in node["children"] - ]) - - def get_entries_at_depth(depth, - entries_and_traces, - node, - curr_depth=0, - trace=()): - if ignore_sampler and node["entry"]["name"] == "Sampler": - return - - if (depth >= 0 and depth == curr_depth) or ( - depth < 0 - and largest_dist_from_leaf(node) == (abs(depth) - 1)): - entries_and_traces.append((node["entry"], trace)) - trace = (node["entry"]["name"], ) + trace - for child in node["children"]: - get_entries_at_depth(depth, - entries_and_traces, - child, - curr_depth=curr_depth + 1, - trace=trace) - - for root in profile_data["prefill"]["summary_stats"]: - get_entries_at_depth(depth, prefill_entries_and_traces, root) - for root in profile_data["decode_1"]["summary_stats"]: - get_entries_at_depth(depth, decode_entries_and_traces, root) - - def 
attempt_to_make_names_unique(entries_and_traces): - names, non_unique_names = (set(), set()) - - def all_the_same(items) -> bool: - return all(i == items[0] for i in items) - - for entry, _ in entries_and_traces: - if entry["name"] in names: - non_unique_names.add(entry["name"]) - else: - names.add(entry["name"]) - - for name in non_unique_names: - entries_and_traces_with_name = [ - (entry, trace) for entry, trace in entries_and_traces - if entry["name"] == name - ] - - zipped_traces = list( - zip(*[trace for _, trace in entries_and_traces_with_name])) - first_trace_difference = next( - (i for i, trace_eles in enumerate(zipped_traces) - if not all_the_same(trace_eles)), None) - - if first_trace_difference is None: - # can't create a unique name, leave them names as the - # are they will get aggregated by the pivot_table call - continue - - for entry, trace in entries_and_traces_with_name: - entry["name"] = " <- ".join((entry["name"], ) + - trace[:first_trace_difference + 1]) - - if make_names_unique: - attempt_to_make_names_unique(prefill_entries_and_traces) - attempt_to_make_names_unique(decode_entries_and_traces) - - def keep_only_top_entries(df, metric, top_k=9): - df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, - ["name"]] = "others" - - prefill_df = pd.DataFrame( - [entry for entry, _ in prefill_entries_and_traces]) - prefill_df["phase"] = "prefill" - decode_df = pd.DataFrame([entry for entry, _ in decode_entries_and_traces]) - decode_df["phase"] = "decode" - - if top_k: - keep_only_top_entries(prefill_df, "cuda_time_us", top_k) - keep_only_top_entries(decode_df, "cuda_time_us", top_k) - - df = pd.concat([prefill_df, decode_df]) - df["cuda_time_ms"] = df["cuda_time_us"] / 1000 - - fig, axes = plt.subplots(2, figsize=(5, 8), sharex=True) - - def plot_metric(metric: str, ax, add_totals=False): - pivoted_df = df.pivot_table(index="phase", - columns="name", - values=metric, - aggfunc="sum") - pivoted_df.plot.bar(stacked=True, legend=False, ax=ax) - ax.set_ylabel(metric) - - if add_totals: - ax.bar_label(ax.containers[-1]) - - plot_metric("cuda_time_ms", ax=axes[0], add_totals=True) - plot_metric("pct_cuda_time", ax=axes[1]) - - handles, labels = plt.gca().get_legend_handles_labels() - legend = fig.legend(handles, - labels, - loc='center left', - bbox_to_anchor=(0.93, 0.5)) - shorten_plot_legend_strings(legend, 50) + output_directory = args.output_directory if args.output_directory else Path( + args.json_trace).parent - context = profile_data["context"] - sparsity = context.get('sparsity', None) - plt.suptitle(f"{context['model']}\n" - f"Batch={context['batch_size']}, " - f"PromptLen={context['prompt_len']}, " - f"NumGpus={context['tensor_parallel_size']}" - f"{', Sparsity ' + sparsity if sparsity else ''}") - plt.savefig(output, bbox_inches='tight') - print("Created: ", output) + main(Path(args.json_trace), output_directory, depth, + args.plot_metric, make_names_unique, args.top_k, args.fold_json_node)
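
For readers who want to exercise the new tree-walking helpers without a full profile, the sketch below builds a toy trace node in the shape the code above expects: nested dicts of {"entry": ..., "children": [...]}, where each entry carries at least a "name" and a "cuda_time_us" field. The key names come from this diff; the module/kernel names and timings are invented, and the import assumes the repo root is on PYTHONPATH so the script is importable as a module.

    from neuralmagic.tools.profiler.visualize_trace import (fold_nodes,
                                                            get_entries_at_depth)

    # Toy tree in the layout of a summary_stats node; names/timings are made up.
    toy_root = {
        "entry": {"name": "LlamaForCausalLM", "cuda_time_us": 15.0},
        "children": [
            {
                "entry": {"name": "RMSNorm", "cuda_time_us": 10.0},
                "children": [{
                    "entry": {"name": "rms_norm_kernel", "cuda_time_us": 10.0},
                    "children": [],
                }],
            },
            {
                "entry": {"name": "Sampler", "cuda_time_us": 5.0},
                "children": [{
                    "entry": {"name": "argmax_kernel", "cuda_time_us": 5.0},
                    "children": [],
                }],
            },
        ],
    }

    # Folding "Sampler" drops its children, turning it into a leaf node.
    toy_root = fold_nodes(toy_root, ["Sampler"])

    # depth == -1 ("kernel" level) collects leaves; depth == -2 ("module"
    # level) collects nodes sitting one level above their deepest leaf.
    for depth in (-1, -2):
        entries_and_traces = []
        get_entries_at_depth(depth, entries_and_traces, toy_root)
        print(depth, [entry["name"] for entry, _ in entries_and_traces])
    # For this toy tree: -1 -> ['rms_norm_kernel', 'Sampler'], -2 -> ['RMSNorm']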
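
The plotting path flattens each step's entries into rows of (phase, name, metric), pivots them into one column per operation, and then group_trace_by_operations collapses related columns into the stacked-bar buckets (attention, gemm_ops, rms_norm_ops, and so on). Below is a minimal pandas sketch of that flow, under the same importability assumption as above; the kernel names are fabricated but chosen to match the substring checks in the grouping helpers.

    import pandas as pd

    from neuralmagic.tools.profiler.visualize_trace import group_trace_by_operations

    # Rows shaped like the per-step dataframes built in prepare_data();
    # the kernel names below are fabricated but hit the substring checks.
    df = pd.DataFrame([
        {"phase": "prefill",  "name": "void cutlass::Kernel<gemm>", "cuda_time_ms": 3.0},
        {"phase": "prefill",  "name": "flash_fwd_kernel",           "cuda_time_ms": 1.5},
        {"phase": "decode_1", "name": "void cutlass::Kernel<gemm>", "cuda_time_ms": 0.8},
        {"phase": "decode_1", "name": "rms_norm_kernel",            "cuda_time_ms": 0.1},
    ])

    # Same pivot plot_trace_df() performs: one row per phase, one column per op.
    pivoted = df.pivot_table(index="phase", columns="name",
                             values="cuda_time_ms", aggfunc="sum")

    # Collapse the op columns into the high-level buckets that get stacked
    # in the bar chart (gemm_ops, attention, rms_norm_ops, ...).
    print(group_trace_by_operations(pivoted))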
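
The decode plot intentionally subsamples: main() keeps one decode trace every --step-plot-interval steps and always re-appends the final decode so the tail of generation stays visible. The selection logic is isolated below with toy step keys (in main() the interval comes from the parsed CLI arguments; here it is just a local variable).

    # Decode-step selection as done in main(): keep every Nth decode key,
    # then make sure the very last one is always included.
    step_plot_interval = 4                                # --step-plot-interval default
    all_decodes = [f"decode_{i}" for i in range(1, 11)]   # toy step keys

    decodes = all_decodes[::step_plot_interval]
    if decodes[-1] != all_decodes[-1]:
        decodes.append(all_decodes[-1])

    print(decodes)  # ['decode_1', 'decode_5', 'decode_9', 'decode_10']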